From 27d9b01342f3956632521951be18bb487a95a352 Mon Sep 17 00:00:00 2001 From: timothijoe Date: Sat, 10 Jun 2023 11:33:40 +0800 Subject: [PATCH 01/28] add stochastic mz ptree --- lzero/entry/train_muzero.py | 4 +- lzero/mcts/buffer/__init__.py | 1 + .../buffer/game_buffer_stochastic_muzero.py | 696 ++++++++++++++ lzero/mcts/ptree/ptree_stochastic_mz.py | 618 +++++++++++++ lzero/mcts/tree_search/__init__.py | 1 + .../mcts/tree_search/mcts_ptree_stochastic.py | 244 +++++ lzero/model/stochastic_muzero_model.py | 873 ++++++++++++++++++ lzero/policy/stochastic_muzero.py | 738 +++++++++++++++ zoo/atari/config/atari_stochastic_muzero.py | 99 ++ 9 files changed, 3273 insertions(+), 1 deletion(-) create mode 100644 lzero/mcts/buffer/game_buffer_stochastic_muzero.py create mode 100644 lzero/mcts/ptree/ptree_stochastic_mz.py create mode 100644 lzero/mcts/tree_search/mcts_ptree_stochastic.py create mode 100644 lzero/model/stochastic_muzero_model.py create mode 100644 lzero/policy/stochastic_muzero.py create mode 100644 zoo/atari/config/atari_stochastic_muzero.py diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index 06ae053cc..8cac8c1f7 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -44,7 +44,7 @@ def train_muzero( """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'], \ + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'" if create_cfg.policy.type == 'muzero': @@ -55,6 +55,8 @@ def train_muzero( from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer elif create_cfg.policy.type == 'gumbel_muzero': from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer + elif create_cfg.policy.type == 'stochastic_muzero': + from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index e7fdf4a97..31680a75e 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -2,3 +2,4 @@ from .game_buffer_efficientzero import EfficientZeroGameBuffer from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer +from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer diff --git a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py new file mode 100644 index 000000000..6dfe6d8e2 --- /dev/null +++ b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py @@ -0,0 +1,696 @@ +from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional + +import numpy as np +import torch +from ding.utils import BUFFER_REGISTRY + +from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts.tree_search.mcts_ptree_stochastic import StochasticMuZeroMCTSPtree as MCTSPtree +from lzero.mcts.utils import prepare_observation +from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform +from .game_buffer import GameBuffer + +if TYPE_CHECKING: + from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy + + +@BUFFER_REGISTRY.register('game_buffer_stochastic_muzero') +class 
StochasticMuZeroGameBuffer(GameBuffer): + """ + Overview: + The specific game buffer for MuZero policy. + """ + + def __init__(self, cfg: dict): + super().__init__(cfg) + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + assert self._cfg.env_type in ['not_board_games', 'board_games'] + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.keep_ratio = 1 + self.model_update_interval = 10 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + + def sample( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. + """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + # obtain the current_batch and prepare target context + reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + batch_size, self._cfg.reanalyze_ratio + ) + # target reward, target value + batch_rewards, batch_target_values = self._compute_target_reward_value( + reward_value_context, policy._target_model + ) + # target policy + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) + batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( + policy_non_re_context, self._cfg.model.action_space_size + ) + + # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies + if 0 < self._cfg.reanalyze_ratio < 1: + batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) + elif self._cfg.reanalyze_ratio == 1: + batch_target_policies = batch_target_policies_re + elif self._cfg.reanalyze_ratio == 0: + batch_target_policies = batch_target_policies_non_re + + target_batch = [batch_rewards, batch_target_values, batch_target_policies] + + # a batch contains the current_batch and the target_batch + train_data = [current_batch, target_batch] + return train_data + + def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: + """ + Overview: + first sample orig_data through ``_sample_orig_data()``, + then prepare the context of a batch: + reward_value_context: the context of reanalyzed value targets + policy_re_context: the context of reanalyzed policy targets + policy_non_re_context: the context of non-reanalyzed policy targets + current_batch: the inputs of batch + Arguments: + - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. 
+ - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) + Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + # obtain the batch context from replay buffer + orig_data = self._sample_orig_data(batch_size) + game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data + batch_size = len(batch_index_list) + obs_list, action_list, mask_list = [], [], [] + # prepare the inputs of a batch + for i in range(batch_size): + game = game_segment_list[i] + pos_in_game_segment = pos_in_game_segment_list[i] + + actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + # add mask for invalid actions (out of trajectory) + mask_tmp = [1. for i in range(len(actions_tmp))] + mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))] + + # pad random action + actions_tmp += [ + np.random.randint(0, game.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + + # obtain the input observations + # pad if length of obs in game_segment is less than stack+num_unroll_steps + # e.g. stack+num_unroll_steps 4+5 + obs_list.append( + game_segment_list[i].get_unroll_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + ) + ) + action_list.append(actions_tmp) + mask_list.append(mask_tmp) + + # formalize the input observations + obs_list = prepare_observation(obs_list, self._cfg.model.model_type) + + # formalize the inputs of a batch + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] + for i in range(len(current_batch)): + current_batch[i] = np.asarray(current_batch[i]) + + total_transitions = self.get_num_of_transitions() + + # obtain the context of value targets + reward_value_context = self._prepare_reward_value_context( + batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions + ) + """ + only reanalyze recent reanalyze_ratio (e.g. 50%) data + if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps + 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy + """ + reanalyze_num = int(batch_size * reanalyze_ratio) + # reanalyzed policy + if reanalyze_num > 0: + # obtain the context of reanalyzed policy targets + policy_re_context = self._prepare_policy_reanalyzed_context( + batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num], + pos_in_game_segment_list[:reanalyze_num] + ) + else: + policy_re_context = None + + # non reanalyzed policy + if reanalyze_num < batch_size: + # obtain the context of non-reanalyzed policy targets + policy_non_re_context = self._prepare_policy_non_reanalyzed_context( + batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:], + pos_in_game_segment_list[reanalyze_num:] + ) + else: + policy_non_re_context = None + + context = reward_value_context, policy_re_context, policy_non_re_context, current_batch + return context + + def _prepare_reward_value_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], + total_transitions: int + ) -> List[Any]: + """ + Overview: + prepare the context of rewards and values for calculating TD value target in reanalyzing part. 
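+ For every unrolled step it records the stacked observation o_{t+td_steps} used to bootstrap the value target, together with a mask marking positions that fall outside the game segment.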
+ Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment + - total_transitions (:obj:`int`): number of collected transitions + Returns: + - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, + td_steps_list, action_mask_segment, to_play_segment + """ + zero_obs = game_segment_list[0].zero_obs() + value_obs_list = [] + # the value is valid or not (out of game_segment) + value_mask = [] + rewards_list = [] + game_segment_lens = [] + # for board games + action_mask_segment, to_play_segment = [], [] + + td_steps_list = [] + for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + + td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) + + # prepare the corresponding observations for bootstrapped values o_{t+k} + # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] + # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] + game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) + + rewards_list.append(game_segment.reward_segment) + + # for board games + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + # get the bootstrapped target obs + td_steps_list.append(td_steps) + # index of bootstrapped obs o_{t+td_steps} + bootstrap_index = current_index + td_steps + + if bootstrap_index < game_segment_len: + value_mask.append(1) + # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + # the stacked obs in time t + obs = game_obs[beg_index:end_index] + else: + value_mask.append(0) + obs = zero_obs + + value_obs_list.append(obs) + + reward_value_context = [ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, + action_mask_segment, to_play_segment + ] + return reward_value_context + + def _prepare_policy_non_reanalyzed_context( + self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int] + ) -> List[Any]: + """ + Overview: + prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play + Arguments: + - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer + - game_segment_list (:obj:`list`): list of game segments + - pos_in_game_segment_list (:obj:`list`): list transition index in game + Returns: + - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + child_visits = [] + game_segment_lens = [] + # for board games + action_mask_segment, to_play_segment = [], [] + + for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + # for board games + action_mask_segment.append(game_segment.action_mask_segment) + 
to_play_segment.append(game_segment.to_play_segment) + + child_visits.append(game_segment.child_visit_segment) + + policy_non_re_context = [ + pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment + ] + return policy_non_re_context + + def _prepare_policy_reanalyzed_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] + ) -> List[Any]: + """ + Overview: + prepare the context of policies for calculating policy target in reanalyzing part. + Arguments: + - batch_index_list (:obj:'list'): start transition index in the replay buffer + - game_segment_list (:obj:'list'): list of game segments + - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history + Returns: + - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, + child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + zero_obs = game_segment_list[0].zero_obs() + with torch.no_grad(): + # for policy + policy_obs_list = [] + policy_mask = [] + # 0 -> Invalid target policy for padding outside of game segments, + # 1 -> Previous target policy for game segments. + rewards, child_visits, game_segment_lens = [], [], [] + # for board games + action_mask_segment, to_play_segment = [], [] + for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): + game_segment_len = len(game_segment) + game_segment_lens.append(game_segment_len) + rewards.append(game_segment.reward_segment) + # for board games + action_mask_segment.append(game_segment.action_mask_segment) + to_play_segment.append(game_segment.to_play_segment) + + child_visits.append(game_segment.child_visit_segment) + # prepare the corresponding observations + game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + + if current_index < game_segment_len: + policy_mask.append(1) + beg_index = current_index - state_index + end_index = beg_index + self._cfg.model.frame_stack_num + obs = game_obs[beg_index:end_index] + else: + policy_mask.append(0) + obs = zero_obs + policy_obs_list.append(obs) + + policy_re_context = [ + policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, + action_mask_segment, to_play_segment + ] + return policy_re_context + + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]: + """ + Overview: + prepare reward and value targets from the context of rewards and values. 
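+ The target value at step t is the n-step return: the discounted sum of the next ``td_steps`` rewards plus the discounted value predicted by the target model at o_{t+td_steps} (or the MCTS root value when ``use_root_value`` is set).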
+ Arguments: + - reward_value_context (:obj:'list'): the reward value context + - model (:obj:'torch.tensor'):model of the target model + Returns: + - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix + - batch_target_values (:obj:'np.ndarray): batch of value estimation + """ + value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ + to_play_segment = reward_value_context # noqa + # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) + transition_batch_size = len(value_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + batch_target_values, batch_rewards = [], [] + with torch.no_grad(): + value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + + m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() + + # calculate the target value + m_output = model.initial_inference(m_obs) + + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + + network_output.append(m_output) + + # concat the output slices after model inference + if self._cfg.use_root_value: + # use the root values from MCTS, as in EfficiientZero + # the root values have limited improvement but require much more GPU actors; + _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') + reward_pool = reward_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + 
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + + roots_values = roots.get_values() + value_list = np.array(roots_values) + else: + # use the predicted values + value_list = concat_output_value(network_output) + + # get last state value + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + value_list = value_list.reshape(-1) * np.array( + [ + self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % + 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] + for i in range(transition_batch_size) + ] + ) + else: + value_list = value_list.reshape(-1) * ( + np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list + ) + + value_list = value_list * np.array(value_mask) + value_list = value_list.tolist() + horizon_id, value_index = 0, 0 + + for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, + pos_in_game_segment_list, + to_play_segment): + target_values = [] + target_rewards = [] + base_index = state_index + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + bootstrap_index = current_index + td_steps_list[value_index] + # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): + for i, reward in enumerate(reward_list[current_index:bootstrap_index]): + if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: + # TODO(pu): for board_games, very important, to check + if to_play_list[base_index] == to_play_list[i]: + value_list[value_index] += reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += -reward * self._cfg.discount_factor ** i + else: + value_list[value_index] += reward * self._cfg.discount_factor ** i + horizon_id += 1 + + if current_index < game_segment_len_non_re: + target_values.append(value_list[value_index]) + target_rewards.append(reward_list[current_index]) + else: + target_values.append(0) + target_rewards.append(0.0) + # TODO: check + # target_rewards.append(reward) + value_index += 1 + + batch_rewards.append(target_rewards) + batch_target_values.append(target_values) + + batch_rewards = np.asarray(batch_rewards, dtype=object) + batch_target_values = np.asarray(batch_target_values, dtype=object) + return batch_rewards, batch_target_values + + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: + """ + Overview: + prepare policy targets from the reanalyzed context of policies + Arguments: + - policy_re_context (:obj:`List`): List of policy context to reanalyzed + Returns: + - batch_target_policies_re + """ + if policy_re_context is None: + return [] + batch_target_policies_re = [] + + # for board games + policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ + to_play_segment = policy_re_context # noqa + # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + transition_batch_size = len(policy_obs_list) + game_segment_batch_size = len(pos_in_game_segment_list) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. 
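+ # Here the mask is only a placeholder: all ``action_space_size`` slots are marked legal so the
+ # discrete and continuous code paths can be shared; the -1 entries below mean 'no discrete action id'.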
+ action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + with torch.no_grad(): + policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) + # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors + slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) + network_output = [] + for i in range(slices): + beg_index = self._cfg.mini_infer_size * i + end_index = self._cfg.mini_infer_size * (i + 1) + m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() + m_output = model.initial_inference(m_obs) + if not model.training: + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), + m_output.policy_logits + ] + ) + + network_output.append(m_output) + + _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') + reward_pool = reward_pool.squeeze().tolist() + policy_logits_pool = policy_logits_pool.tolist() + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + ).astype(np.float32).tolist() for _ in range(transition_batch_size) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + else: + # python mcts_tree + roots = MCTSPtree.roots(transition_batch_size, legal_actions) + roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) + # do MCTS for a new policy with the recent target model + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + + roots_legal_actions_list = legal_actions + roots_distributions = roots.get_distributions() + policy_index = 0 + for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list): + target_policies = [] + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + distributions = roots_distributions[policy_index] + + if policy_mask[policy_index] == 0: + # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 + target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + else: + if distributions is None: + # if at some obs, the legal_action is None, add the fake target_policy + target_policies.append( + list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + ) + else: + if self._cfg.env_type == 'not_board_games': + # for atari/classic_control/box2d environments that only have one player. 
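+ # Normalize the MCTS visit counts into a probability distribution,
+ # e.g. distributions = [8, 2, 0, 10] -> policy = [0.4, 0.1, 0.0, 0.5].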
+ sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + target_policies.append(policy) + else: + # for board games that have two players and legal_actions is dy + policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + # to make sure target_policies have the same dimension + sum_visits = sum(distributions) + policy = [visit_count / sum_visits for visit_count in distributions] + for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): + policy_tmp[legal_action] = policy[index] + target_policies.append(policy_tmp) + + policy_index += 1 + + batch_target_policies_re.append(target_policies) + + batch_target_policies_re = np.array(batch_target_policies_re) + + return batch_target_policies_re + + def _compute_target_policy_non_reanalyzed( + self, policy_non_re_context: List[Any], policy_shape: Optional[int] + ) -> np.ndarray: + """ + Overview: + prepare policy targets from the non-reanalyzed context of policies + Arguments: + - policy_non_re_context (:obj:`List`): List containing: + - pos_in_game_segment_list + - child_visits + - game_segment_lens + - action_mask_segment + - to_play_segment + - policy_shape: self._cfg.model.action_space_size + Returns: + - batch_target_policies_non_re + """ + batch_target_policies_non_re = [] + if policy_non_re_context is None: + return batch_target_policies_non_re + + pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context + game_segment_batch_size = len(pos_in_game_segment_list) + transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) + + to_play, action_mask = self._preprocess_to_play_and_action_mask( + game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ) + + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. + action_mask = [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + ] + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + ] + else: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] + + with torch.no_grad(): + policy_index = 0 + # 0 -> Invalid target policy for padding outside of game segments, + # 1 -> Previous target policy for game segments. + policy_mask = [] + for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits, + pos_in_game_segment_list): + target_policies = [] + + for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): + if current_index < game_segment_len: + policy_mask.append(1) + # NOTE: child_visit is already a distribution + distributions = child_visit[current_index] + if self._cfg.env_type == 'not_board_games': + # for atari/classic_control/box2d environments that only have one player. + target_policies.append(distributions) + else: + # for board games that have two players. 
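+ # Scatter the visit-count distribution, which is defined only over the legal actions,
+ # back into a fixed-size vector of length ``policy_shape`` so every target policy has the same dimension.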
+ policy_tmp = [0 for _ in range(policy_shape)] + for index, legal_action in enumerate(legal_actions[policy_index]): + # only the action in ``legal_action`` the policy logits is nonzero + policy_tmp[legal_action] = distributions[index] + target_policies.append(policy_tmp) + else: + # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 + policy_mask.append(0) + target_policies.append([0 for _ in range(policy_shape)]) + + policy_index += 1 + + batch_target_policies_non_re.append(target_policies) + batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) + return batch_target_policies_non_re + + def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: + """ + Overview: + Update the priority of training data. + Arguments: + - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. + - batch_priorities (:obj:`batch_priorities`): priorities to update to. + NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights, make_time_list] + """ + indices = train_data[0][3] + metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + if metas['make_time'][i] > self.clear_time: + idx, prio = indices[i], metas['batch_priorities'][i] + self.game_pos_priorities[idx] = prio diff --git a/lzero/mcts/ptree/ptree_stochastic_mz.py b/lzero/mcts/ptree/ptree_stochastic_mz.py new file mode 100644 index 000000000..5e17c8b32 --- /dev/null +++ b/lzero/mcts/ptree/ptree_stochastic_mz.py @@ -0,0 +1,618 @@ +""" +The Node, Roots class and related core functions for MuZero. +""" +import math +import random +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch + +from .minimax import MinMaxStats + + +class Node: + """ + Overview: + the node base class for MuZero. + Arguments: + """ + + def __init__(self, prior: float, legal_actions: List = None, action_space_size: int = 9, is_chance: bool = False, chance_space_size: int = 2) -> None: + self.prior = prior + self.legal_actions = legal_actions + self.action_space_size = action_space_size + + self.visit_count = 0 + self.value_sum = 0 + self.best_action = -1 + self.to_play = 0 # default 0 means play_with_bot_mode + self.reward = 0 + self.value_prefix = 0.0 + self.children = {} + self.children_index = [] + self.latent_state_index_in_search_path = 0 + self.latent_state_index_in_batch = 0 + self.parent_value_prefix = 0 # only used in update_tree_q method + + self.is_chance = is_chance + self.chance_space_size = chance_space_size + + def expand( + self, to_play: int, latent_state_index_in_search_path: int, latent_state_index_in_batch: int, reward: float, + policy_logits: List[float], child_is_chance: bool = False + ) -> None: + """ + Overview: + Expand the child nodes of the current node. + Arguments: + - to_play (:obj:`Class int`): which player to play the game in the current node. + - latent_state_index_in_search_path (:obj:`Class int`): the x/first index of latent state vector of the current node, i.e. the search depth. + - latent_state_index_in_batch (:obj:`Class int`): the y/second index of latent state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - value_prefix: (:obj:`Class float`): the value prefix of the current node. 
+ - policy_logits: (:obj:`Class List`): the policy logit of the child nodes. + """ + self.to_play = to_play + self.reward = reward + + # assert (self.is_chance != child_is_chance), f"is_chance and child_is_chance should be different, current is {self.is_chance}-{child_is_chance}, " + + if self.is_chance is True: + child_is_chance = False + self.reward = 0.0 + + if self.legal_actions is None: + # self.legal_actions = np.arange(len(policy_logits)) + self.legal_actions = np.arange(self.chance_space_size) + self.latent_state_index_in_search_path = latent_state_index_in_search_path + self.latent_state_index_in_batch = latent_state_index_in_batch + policy_values = torch.softmax(torch.tensor([policy_logits[a] for a in self.legal_actions]), dim=0).tolist() + policy = {legal_action: policy_values[index] for index, legal_action in enumerate(self.legal_actions)} + for action, prior in policy.items(): + self.children[action] = Node(prior, is_chance=child_is_chance) + else: + child_is_chance = True + #self.legal_actions = np.arange(self.chance_space_size) + self.legal_actions = np.arange(len(policy_logits)) + self.latent_state_index_in_search_path = latent_state_index_in_search_path + self.latent_state_index_in_batch = latent_state_index_in_batch + policy_values = torch.softmax(torch.tensor([policy_logits[a] for a in self.legal_actions]), dim=0).tolist() + policy = {legal_action: policy_values[index] for index, legal_action in enumerate(self.legal_actions)} + for action, prior in policy.items(): + self.children[action] = Node(prior, is_chance=child_is_chance) + + def add_exploration_noise(self, exploration_fraction: float, noises: List[float]) -> None: + """ + Overview: + add exploration noise to priors + Arguments: + - noises (:obj: list): length is len(self.legal_actions) + """ + for i, a in enumerate(self.legal_actions): + """ + i in index, a is action, e.g. self.legal_actions = [0,1,2,4,6,8], i=[0,1,2,3,4,5], a=[0,1,2,4,6,8] + """ + try: + noise = noises[i] + except Exception as error: + print(error) + child = self.get_child(a) + prior = child.prior + child.prior = prior * (1 - exploration_fraction) + noise * exploration_fraction + + def compute_mean_q(self, is_root: int, parent_q: float, discount_factor: float) -> float: + """ + Overview: + Compute the mean q value of the current node. + Arguments: + - is_root (:obj:`int`): whether the current node is a root node. + - parent_q (:obj:`float`): the q value of the parent node. + - discount_factor (:obj:`float`): the discount_factor of reward. + """ + total_unsigned_q = 0.0 + total_visits = 0 + for a in self.legal_actions: + child = self.get_child(a) + if child.visit_count > 0: + true_reward = child.reward + # TODO(pu): only one step bootstrap? + q_of_s_a = true_reward + discount_factor * child.value + total_unsigned_q += q_of_s_a + total_visits += 1 + if is_root and total_visits > 0: + mean_q = total_unsigned_q / total_visits + else: + # if is not root node, + # TODO(pu): why parent_q? + mean_q = (parent_q + total_unsigned_q) / (total_visits + 1) + return mean_q + + def get_trajectory(self) -> List[Union[int, float]]: + """ + Overview: + Find the current best trajectory starts from the current node. + Outputs: + - traj: a vector of node index, which is the current best trajectory from this node. 
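+ e.g. if the chain of ``best_action`` values below this node is 2 -> 0 -> 1, the returned trajectory is [2, 0, 1].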
+ """ + # TODO(pu): best action + traj = [] + node = self + best_action = node.best_action + while best_action >= 0: + traj.append(best_action) + + node = node.get_child(best_action) + best_action = node.best_action + return traj + + def get_children_distribution(self) -> List[Union[int, float]]: + if self.legal_actions == []: + return None + distribution = {a: 0 for a in self.legal_actions} + if self.expanded: + for a in self.legal_actions: + child = self.get_child(a) + distribution[a] = child.visit_count + # only take the visit counts + distribution = [v for k, v in distribution.items()] + return distribution + + def get_child(self, action: Union[int, float]) -> "Node": + """ + Overview: + get children node according to the input action. + """ + if not isinstance(action, np.int64): + action = int(action) + return self.children[action] + + @property + def expanded(self) -> bool: + return len(self.children) > 0 + + @property + def value(self) -> float: + """ + Overview: + Return the estimated value of the current root node. + """ + if self.visit_count == 0: + return 0 + else: + return self.value_sum / self.visit_count + + +class Roots: + + def __init__(self, root_num: int, legal_actions_list: List) -> None: + self.num = root_num + self.root_num = root_num + self.legal_actions_list = legal_actions_list # list of list + + self.roots = [] + for i in range(self.root_num): + if isinstance(legal_actions_list, list): + self.roots.append(Node(0, legal_actions_list[i])) + else: + # if legal_actions_list is int + self.roots.append(Node(0, np.arange(legal_actions_list))) + + def prepare( + self, + root_noise_weight: float, + noises: List[float], + rewards: List[float], + policies: List[List[float]], + to_play: int = -1 + ) -> None: + """ + Overview: + Expand the roots and add noises. + Arguments: + - root_noise_weight: the exploration fraction of roots + - noises: the vector of noise add to the roots. + - rewards: the vector of rewards of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. + """ + for i in range(self.root_num): + # to_play: int, latent_state_index_in_search_path: int, latent_state_index_in_batch: int, + # TODO(pu): why latent_state_index_in_search_path=0, latent_state_index_in_batch=i? + if to_play is None: + self.roots[i].expand(-1, 0, i, rewards[i], policies[i], child_is_chance=True) + else: + self.roots[i].expand(to_play[i], 0, i, rewards[i], policies[i], child_is_chance=True) + + self.roots[i].add_exploration_noise(root_noise_weight, noises[i]) + self.roots[i].visit_count += 1 + + def prepare_no_noise(self, rewards: List[float], policies: List[List[float]], to_play: int = -1) -> None: + """ + Overview: + Expand the roots without noise. + Arguments: + - rewards: the vector of rewards of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. + """ + for i in range(self.root_num): + if to_play is None: + self.roots[i].expand(-1, 0, i, rewards[i], policies[i], child_is_chance=True) + else: + self.roots[i].expand(to_play[i], 0, i, rewards[i], policies[i], child_is_chance=True) + + self.roots[i].visit_count += 1 + + def clear(self) -> None: + self.roots.clear() + + def get_trajectories(self) -> List[List[Union[int, float]]]: + """ + Overview: + Find the current best trajectory starts from each root. + Outputs: + - traj: a vector of node index, which is the current best trajectory from each root. 
+ """ + trajs = [] + for i in range(self.root_num): + trajs.append(self.roots[i].get_trajectory()) + return trajs + + def get_distributions(self) -> List[List[Union[int, float]]]: + """ + Overview: + Get the children distribution of each root. + Outputs: + - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + """ + distributions = [] + for i in range(self.root_num): + distributions.append(self.roots[i].get_children_distribution()) + + return distributions + + def get_values(self) -> float: + """ + Overview: + Return the estimated value of each root. + """ + values = [] + for i in range(self.root_num): + values.append(self.roots[i].value) + return values + + +class SearchResults: + + def __init__(self, num: int) -> None: + self.num = num + self.nodes = [] + self.search_paths = [] + self.latent_state_index_in_search_path = [] + self.latent_state_index_in_batch = [] + self.last_actions = [] + self.search_lens = [] + + +def update_tree_q(root: Node, min_max_stats: MinMaxStats, discount_factor: float, players: int = 1) -> None: + """ + Overview: + Update the value sum and visit count of nodes along the search path. + Arguments: + - search_path: a vector of nodes on the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - to_play: which player to play the game in the current node. + - value: the value to propagate along the search path. + - discount_factor: the discount factor of reward. + """ + node_stack = [] + node_stack.append(root) + while len(node_stack) > 0: + node = node_stack[-1] + node_stack.pop() + + if node != root: + true_reward = node.reward + if players == 1: + q_of_s_a = true_reward + discount_factor * node.value + elif players == 2: + q_of_s_a = true_reward + discount_factor * (-node.value) + + min_max_stats.update(q_of_s_a) + + for a in node.legal_actions: + child = node.get_child(a) + if child.expanded: + node_stack.append(child) + + +def select_child( + node: Node, min_max_stats: MinMaxStats, pb_c_base: float, pb_c_int: float, discount_factor: float, + mean_q: float, players: int +) -> Union[int, float]: + """ + Overview: + Select the child node of the roots according to ucb scores. + Arguments: + - node: the node to select the child node. + - min_max_stats (:obj:`Class MinMaxStats`): a tool used to min-max normalize the score. + - pb_c_base (:obj:`Class Float`): constant c1 used in pUCT rule, typically 1.25. + - pb_c_int (:obj:`Class Float`): constant c2 used in pUCT rule, typically 19652. + - discount_factor (:obj:`Class Float`): discount_factor factor used i calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - mean_q (:obj:`Class Float`): the mean q value of the parent node. + - players (:obj:`Class Int`): the number of players. one/two_player mode board games. + Returns: + - action (:obj:`Union[int, float]`): Choose the action with the highest ucb score. + """ + + if node.is_chance: + # print("root->is_chance: True ") + + # If the node is chance node, we sample from the prior outcome distribution. + outcomes, probs = zip(*[(o, n.prior) for o, n in node.children.items()]) + outcome = np.random.choice(outcomes, p=probs) + # print(outcome, probs) + return outcome + + # print("root->is_chance: False ") + # If the node is decision node, we select the action with the highest ucb score. 
+ max_score = -np.inf + epsilon = 0.000001 + max_index_lst = [] + for a in node.legal_actions: + child = node.get_child(a) + temp_score = compute_ucb_score( + child, min_max_stats, mean_q, node.visit_count, pb_c_base, pb_c_int, discount_factor, players + ) + if max_score < temp_score: + max_score = temp_score + max_index_lst.clear() + max_index_lst.append(a) + elif temp_score >= max_score - epsilon: + # TODO(pu): if the difference is less than epsilon = 0.000001, we random choice action from max_index_lst + max_index_lst.append(a) + + action = 0 + if len(max_index_lst) > 0: + action = random.choice(max_index_lst) + return action + + +def compute_ucb_score( + child: Node, + min_max_stats: MinMaxStats, + parent_mean_q: float, + total_children_visit_counts: float, + pb_c_base: float, + pb_c_init: float, + discount_factor: float, + players: int = 1, +) -> float: + """ + Overview: + Compute the ucb score of the child. + Arguments: + - child: the child node to compute ucb score. + - min_max_stats: a tool used to min-max normalize the score. + - parent_mean_q: the mean q value of the parent node. + - is_reset: whether the value prefix needs to be reset. + - total_children_visit_counts: the total visit counts of the child nodes of the parent node. + - parent_value_prefix: the value prefix of parent node. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - players: the number of players. + - continuous_action_space: whether the action space is continous in current env. + Outputs: + - ucb_value: the ucb score of the child. + """ + pb_c = math.log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init + pb_c *= (math.sqrt(total_children_visit_counts) / (child.visit_count + 1)) + + prior_score = pb_c * child.prior + if child.visit_count == 0: + value_score = parent_mean_q + else: + true_reward = child.reward + if players == 1: + value_score = true_reward + discount_factor * child.value + elif players == 2: + value_score = true_reward + discount_factor * (-child.value) + + value_score = min_max_stats.normalize(value_score) + if value_score < 0: + value_score = 0 + if value_score > 1: + value_score = 1 + ucb_score = prior_score + value_score + + return ucb_score + + +def batch_traverse( + roots: Any, + pb_c_base: float, + pb_c_init: float, + discount_factor: float, + min_max_stats_lst: List[MinMaxStats], + results: SearchResults, + virtual_to_play: List, +) -> Tuple[Any, Any]: + + """ + Overview: + traverse, also called selection. process a batch roots parallely. + Arguments: + - roots (:obj:`Any`): a batch of root nodes to be expanded. + - pb_c_base (:obj:`float`): constant c1 used in pUCT rule, typically 1.25. + - pb_c_init (:obj:`float`): constant c2 used in pUCT rule, typically 19652. + - discount_factor (:obj:`float`): discount_factor factor used i calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and training in board games, + `virtual` is to emphasize that actions are performed on an imaginary hidden state. + - continuous_action_space: whether the action space is continous in current env. + Returns: + - latent_state_index_in_search_path (:obj:`list`): the list of x/first index of latent state vector of the searched node, i.e. the search depth. + - latent_state_index_in_batch (:obj:`list`): the list of y/second index of latent state vector of the searched node, i.e. 
the index of batch root node, its maximum is ``batch_size``/``env_num``. + - last_actions (:obj:`list`): the action performed by the previous node. + - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and trainin gin board games, + `virtual` is to emphasize that actions are performed on an imaginary hidden state. + """ + parent_q = 0.0 + results.search_lens = [None for i in range(results.num)] + results.last_actions = [None for i in range(results.num)] + + results.nodes = [None for i in range(results.num)] + results.latent_state_index_in_search_path = [None for i in range(results.num)] + results.latent_state_index_in_batch = [None for i in range(results.num)] + if virtual_to_play in [1, 2] or virtual_to_play[0] in [1, 2]: + players = 2 + elif virtual_to_play in [-1, None] or virtual_to_play[0] in [-1, None]: + players = 1 + + results.search_paths = {i: [] for i in range(results.num)} + for i in range(results.num): + node = roots.roots[i] + is_root = 1 + search_len = 0 + results.search_paths[i].append(node) + + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. + """ + # the leaf node is not expanded + while node.expanded: + mean_q = node.compute_mean_q(is_root, parent_q, discount_factor) + is_root = 0 + parent_q = mean_q + + # select action according to the pUCT rule. + action = select_child( + node, min_max_stats_lst.stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players + ) + if players == 2: + # Players play turn by turn + if virtual_to_play[i] == 1: + virtual_to_play[i] = 2 + else: + virtual_to_play[i] = 1 + + node.best_action = action + # move to child node according to selected action. + node = node.get_child(action) + + last_action = action + + results.search_paths[i].append(node) + search_len += 1 + + # note this return the parent node of the current searched node + parent = results.search_paths[i][len(results.search_paths[i]) - 1 - 1] + results.latent_state_index_in_search_path[i] = parent.latent_state_index_in_search_path + results.latent_state_index_in_batch[i] = parent.latent_state_index_in_batch + results.last_actions[i] = last_action + results.search_lens[i] = search_len + # while we break out the while loop, results.nodes[i] save the leaf node. + results.nodes[i] = node + + # print(f'env {i} one simulation done!') + # return results.nodes, results.latent_state_index_in_search_path, results.latent_state_index_in_batch, results.last_actions, virtual_to_play + return results, virtual_to_play + + +def backpropagate( + search_path: List[Node], min_max_stats: MinMaxStats, to_play: int, value: float, discount_factor: float +) -> None: + """ + Overview: + Update the value sum and visit count of nodes along the search path. + Arguments: + - search_path: a vector of nodes on the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - to_play: which player to play the game in the current node. + - value: the value to propagate along the search path. + - discount_factor: the discount factor of reward. 
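+ NOTE:
+ In self-play mode the propagated value is accumulated with a sign that depends on whether the node's
+ ``to_play`` matches the backed-up player, so each node stores its value from the perspective of the player to move at that node.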
+ """ + assert to_play is None or to_play in [-1, 1, 2] + if to_play is None or to_play == -1: + # for play-with-bot mode + bootstrap_value = value + path_len = len(search_path) + for i in range(path_len - 1, -1, -1): + node = search_path[i] + node.value_sum += bootstrap_value + node.visit_count += 1 + true_reward = node.reward + min_max_stats.update(true_reward + discount_factor * node.value) + bootstrap_value = true_reward + discount_factor * bootstrap_value + else: + # for self-play-mode + bootstrap_value = value + path_len = len(search_path) + for i in range(path_len - 1, -1, -1): + node = search_path[i] + # to_play related + node.value_sum += bootstrap_value if node.to_play == to_play else -bootstrap_value + + node.visit_count += 1 + + # NOTE: in self-play-mode, + # we should calculate the true_reward according to the perspective of current player of node + # true_reward = node.value_prefix - (- parent_value_prefix) + true_reward = node.reward + + # min_max_stats.update(true_reward + discount_factor * node.value) + min_max_stats.update(true_reward + discount_factor * -node.value) + + # TODO(pu): to_play related + # true_reward is in the perspective of current player of node + # bootstrap_value = (true_reward if node.to_play == to_play else - true_reward) + discount_factor * bootstrap_value + bootstrap_value = ( + -true_reward if node.to_play == to_play else true_reward + ) + discount_factor * bootstrap_value + + +def batch_backpropagate( + latent_state_index_in_search_path: int, + discount_factor: float, + value_prefixs: List[float], + values: List[float], + policies: List[float], + min_max_stats_lst: List[MinMaxStats], + results: SearchResults, + to_play: list = None, + is_chance_list: list = None, + leaf_idx_list: list = None, +) -> None: + """ + Overview: + Backpropagation along the search path to update the attributes. + Arguments: + - latent_state_index_in_search_path (:obj:`Class Int`): the index of latent state vector. + - discount_factor (:obj:`Class Float`): discount_factor factor used i calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - value_prefixs (:obj:`Class List`): the value prefixs of nodes along the search path. + - values (:obj:`Class List`): the values to propagate along the search path. + - policies (:obj:`Class List`): the policy logits of nodes along the search path. + - min_max_stats_lst (:obj:`Class List[MinMaxStats]`): a tool used to min-max normalize the q value. + - results (:obj:`Class List`): the search results. + - to_play (:obj:`Class List`): the batch of which player is playing on this node. 
+ """ + if leaf_idx_list is None: + leaf_idx_list = list(range(results.num)) + # for i in range(results.num): + # for i in leaf_idx_list: + for leaf_order, i in enumerate(leaf_idx_list): + # ****** expand the leaf node ****** + if to_play is None: + # set to_play=-1, because two_player mode to_play = {1,2} + results.nodes[i].expand(-1, latent_state_index_in_search_path, i, value_prefixs[leaf_order], policies[leaf_order], is_chance_list[i]) + else: + results.nodes[i].expand(to_play[i], latent_state_index_in_search_path, i, value_prefixs[leaf_order], policies[leaf_order], is_chance_list[i]) + + # ****** backpropagate ****** + if to_play is None: + backpropagate(results.search_paths[i], min_max_stats_lst.stats_lst[i], 0, values[leaf_order], discount_factor) + else: + backpropagate( + results.search_paths[i], min_max_stats_lst.stats_lst[i], to_play[i], values[leaf_order], discount_factor + ) diff --git a/lzero/mcts/tree_search/__init__.py b/lzero/mcts/tree_search/__init__.py index c89da9a92..581a67844 100644 --- a/lzero/mcts/tree_search/__init__.py +++ b/lzero/mcts/tree_search/__init__.py @@ -2,3 +2,4 @@ from .mcts_ctree_sampled import SampledEfficientZeroMCTSCtree from .mcts_ptree import MuZeroMCTSPtree, EfficientZeroMCTSPtree from .mcts_ptree_sampled import SampledEfficientZeroMCTSPtree +from .mcts_ptree_stochastic import StochasticMuZeroMCTSPtree diff --git a/lzero/mcts/tree_search/mcts_ptree_stochastic.py b/lzero/mcts/tree_search/mcts_ptree_stochastic.py new file mode 100644 index 000000000..23f2b35e1 --- /dev/null +++ b/lzero/mcts/tree_search/mcts_ptree_stochastic.py @@ -0,0 +1,244 @@ +from typing import TYPE_CHECKING, List, Any, Union +from easydict import EasyDict + +import copy +import numpy as np +import torch + +from lzero.mcts.ptree import MinMaxStatsList +from lzero.policy import InverseScalarTransform +import lzero.mcts.ptree.ptree_stochastic_mz as tree_muzero + +if TYPE_CHECKING: + import lzero.mcts.ptree.ptree_stochastic_mz as stochastic_mz_ptree + + +# ============================================================== +# Stochastic MuZero +# ============================================================== + + +class StochasticMuZeroMCTSPtree(object): + """ + Overview: + MCTSPtree for MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in python. + Interfaces: + __init__, search + """ + + # the default_config for MuZeroMCTSPtree. + config = dict( + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_base=19652, + # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_init=1.25, + # (float) The maximum change in value allowed during the backup step of the search tree update. + value_delta_max=0.01, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict = None) -> None: + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. 
Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + @classmethod + def roots(cls: int, root_num: int, legal_actions: List[Any]) -> "stochastic_mz_ptree.Roots": + """ + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num: the number of the current root. + - legal_action_list: the vector of the legal action of this root. + """ + import lzero.mcts.ptree.ptree_stochastic_mz as ptree + return ptree.Roots(root_num, legal_actions) + + def search( + self, + roots: Any, + model: torch.nn.Module, + latent_state_roots: List[Any], + to_play: Union[int, List[Any]] = -1 + ) -> None: + """ + Overview: + Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. + Use the python ctree. + Arguments: + - roots (:obj:`Any`): a batch of expanded root nodes + - latent_state_roots (:obj:`list`): the hidden states of the roots + - to_play (:obj:`list`): the to_play list used in two_player mode board games + """ + with torch.no_grad(): + model.eval() + + # preparation + num = roots.num + device = self._cfg.device + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + # the data storage of hidden states: storing the hidden states of all the ctree root nodes + # latent_state_roots.shape (2, 12, 3, 3) + latent_state_batch_in_search_path = [latent_state_roots] + + # the index of each layer in the ctree + current_latent_state_index = 0 + # minimax value storage + min_max_stats_lst = MinMaxStatsList(num) + + for simulation_index in range(self._cfg.num_simulations): + # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. + + latent_states = [] + + # prepare a result wrapper to transport results between python and c++ parts + results = tree_muzero.SearchResults(num=num) + + # latent_state_index_in_search_path: The first index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, that is, the search depth. + # latent_state_index_in_batch: The second index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. + # e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index. + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. 
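+ In Stochastic MuZero the traversal alternates between decision nodes, where the pUCT rule selects an action, and chance nodes, where an outcome is sampled from the learned chance prior.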
+ """ + # leaf_nodes, latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play = tree_muzero.batch_traverse( + # roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, copy.deepcopy(to_play) + # ) + results, virtual_to_play = tree_muzero.batch_traverse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, copy.deepcopy(to_play) + ) + leaf_nodes, latent_state_index_in_search_path, latent_state_index_in_batch, last_actions = results.nodes, results.latent_state_index_in_search_path, results.latent_state_index_in_batch, results.last_actions + + # obtain the states for leaf nodes + for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): + latent_states.append( + latent_state_batch_in_search_path[ix][iy]) # latent_state_batch_in_search_path[ix][iy] shape e.g. (64,4,4) + + latent_states = torch.from_numpy(np.asarray(latent_states)).to(device).float() + # only for discrete action + last_actions = torch.from_numpy(np.asarray(last_actions)).to(device).long() + """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. + """ + # network_output = model.recurrent_inference(latent_states, last_actions) + + num = len(leaf_nodes) + latent_state_batch = [None] * num + value_batch = [None] * num + reward_batch = [None] * num + policy_logits_batch = [None] * num + child_is_chance_batch = [None] * num + chance_nodes = [] + decision_nodes = [] + for i, node in enumerate(leaf_nodes): + if node.is_chance: + chance_nodes.append(i) + else: + decision_nodes.append(i) + + def process_nodes(node_indices, is_chance): + # Return early if node_indices is empty + if not node_indices: + return + # Slice and stack latent_states and last_actions based on node_indices + latent_states_stack = torch.stack([latent_states[i] for i in node_indices], dim=0) + last_actions_stack = torch.stack([last_actions[i] for i in node_indices], dim=0) + + # Pass the stacked batch through the recurrent_inference function + network_output_batch = model.recurrent_inference(latent_states_stack, + last_actions_stack, + afterstate=not is_chance) + + # Split the batch output into separate nodes + latent_state_splits = torch.split(network_output_batch.latent_state, 1, dim=0) + value_splits = torch.split(network_output_batch.value, 1, dim=0) + reward_splits = torch.split(network_output_batch.reward, 1, dim=0) + policy_logits_splits = torch.split(network_output_batch.policy_logits, 1, dim=0) + + for i, (latent_state, value, reward, policy_logits) in zip(node_indices, + zip(latent_state_splits, value_splits, + reward_splits, + policy_logits_splits)): + if not model.training: + value = self.inverse_scalar_transform_handle(value).detach().cpu().numpy() + reward = self.inverse_scalar_transform_handle(reward).detach().cpu().numpy() + latent_state = latent_state.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy() + + latent_state_batch[i] = latent_state + value_batch[i] = value.reshape(-1).tolist() + reward_batch[i] = reward.reshape(-1).tolist() + policy_logits_batch[i] = policy_logits.tolist() + child_is_chance_batch[i] = is_chance + + process_nodes(chance_nodes, True) + 
process_nodes(decision_nodes, False) + + # latent_state_batch_chance = [latent_state_batch[leaf_idx] for leaf_idx in chance_nodes] + # latent_state_batch_decision = [latent_state_batch[leaf_idx] for leaf_idx in decision_nodes] + value_batch_chance = [value_batch[leaf_idx] for leaf_idx in chance_nodes] + value_batch_decision = [value_batch[leaf_idx] for leaf_idx in decision_nodes] + reward_batch_chance = [reward_batch[leaf_idx] for leaf_idx in chance_nodes] + reward_batch_decision = [reward_batch[leaf_idx] for leaf_idx in decision_nodes] + policy_logits_batch_chance = [policy_logits_batch[leaf_idx] for leaf_idx in chance_nodes] + policy_logits_batch_decision = [policy_logits_batch[leaf_idx] for leaf_idx in decision_nodes] + + latent_state_batch = np.concatenate(latent_state_batch, axis=0) + latent_state_batch_in_search_path.append(latent_state_batch) + current_latent_state_index = simulation_index + 1 + + if(len(chance_nodes) > 0): + value_batch_chance = np.concatenate(value_batch_chance, axis=0) + reward_batch_chance = np.concatenate(reward_batch_chance, axis=0) + policy_logits_batch_chance = np.concatenate(policy_logits_batch_chance, axis=0) + tree_muzero.batch_backpropagate( + current_latent_state_index, discount_factor, reward_batch_chance, value_batch_chance, policy_logits_batch_chance, + min_max_stats_lst, results, virtual_to_play, child_is_chance_batch, chance_nodes + ) + if(len(decision_nodes)>0): + value_batch_decision = np.concatenate(value_batch_decision, axis=0) + reward_batch_decision = np.concatenate(reward_batch_decision, axis=0) + policy_logits_batch_decision = np.concatenate(policy_logits_batch_decision, axis=0) + tree_muzero.batch_backpropagate( + current_latent_state_index, discount_factor, reward_batch_decision, value_batch_decision, policy_logits_batch_decision, + min_max_stats_lst, results, virtual_to_play, child_is_chance_batch, decision_nodes + ) + + # latent_state_batch = np.concatenate(latent_state_batch, axis=0) + # value_batch = np.concatenate(value_batch, axis=0) + # reward_batch = np.concatenate(reward_batch, axis=0) + # policy_logits_batch = np.concatenate(policy_logits_batch, axis=0) + # latent_state_batch_in_search_path.append(latent_state_batch) + + # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # statistics. + + # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. + # current_latent_state_index = simulation_index + 1 + # tree_muzero.batch_backpropagate( + # current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, + # min_max_stats_lst, results, virtual_to_play, child_is_chance_batch + # ) diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py new file mode 100644 index 000000000..95bc1c5ba --- /dev/null +++ b/lzero/model/stochastic_muzero_model.py @@ -0,0 +1,873 @@ +""" +Overview: + BTW, users can refer to the unittest of these model templates to learn how to use them. 
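A minimal usage sketch of the two-stage (decision -> afterstate -> chance) rollout implemented in this file; batch size, shapes and sizes below are hypothetical and assume the default conv model defined here:

    >>> model = StochasticMuZeroModel(observation_shape=(4, 96, 96), action_space_size=6, chance_space_size=2)
    >>> obs = torch.randn(8, 4, 96, 96)
    >>> out = model.initial_inference(obs)                                             # latent state s_t, value, action policy
    >>> action = torch.randint(0, 6, (8,))
    >>> after = model.recurrent_inference(out.latent_state, action, afterstate=False)  # afterstate, chance logits, value
    >>> chance = torch.randint(0, 2, (8,))
    >>> nxt = model.recurrent_inference(after.latent_state, chance, afterstate=True)   # next latent state, reward, action policy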
+""" +from typing import Optional, Tuple + +import math +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType + +from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. +@MODEL_REGISTRY.register('StochasticMuZeroModel') +class StochasticMuZeroModel(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + chance_space_size: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + *args, + **kwargs + ): + """ + Overview: + The definition of the neural network model used in MuZero. + MuZero model which consists of a representation network, a dynamics network and a prediction network. + The networks are build on convolution residual blocks and fully connected layers. + Arguments: + - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - reward_head_channels (:obj:`int`): The channels of reward head. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + in MuZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \ + distribution for value and reward. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). 
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ + dynamics/prediction mlp, default set it to True. + - state_norm (:obj:`bool`): Whether to use normalization for hidden states, default set it to False. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + """ + super(StochasticMuZeroModel, self).__init__() + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # for vector obs input, e.g. classical control ad box2d environments + # to be compatible with LightZero model/policy, transform to shape: [C, W, H] + observation_shape = [1, observation_shape, 1] + + self.categorical_distribution = categorical_distribution + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + self.action_space_size = action_space_size + self.chance_space_size = chance_space_size + + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.downsample = downsample + + flatten_output_size_for_reward_head = ( + (reward_head_channels * math.ceil(observation_shape[1] / 16) * + math.ceil(observation_shape[2] / 16)) if downsample else + (reward_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_output_size_for_value_head = ( + (value_head_channels * math.ceil(observation_shape[1] / 16) * + math.ceil(observation_shape[2] / 16)) if downsample else + (value_head_channels * observation_shape[1] * observation_shape[2]) + ) + flatten_output_size_for_policy_head = ( + (policy_head_channels * math.ceil(observation_shape[1] / 16) * + math.ceil(observation_shape[2] / 16)) if downsample else + (policy_head_channels * observation_shape[1] * observation_shape[2]) + ) + + self.representation_network = RepresentationNetwork( + observation_shape, + num_res_blocks, + num_channels, + downsample, + ) + + self.encoder = Encoder_function( + observation_shape, chance_space_size + ) + self.dynamics_network = DynamicsNetwork( + num_res_blocks, + num_channels + 1, + reward_head_channels, + fc_reward_layers, + self.reward_support_size, + flatten_output_size_for_reward_head, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + ) + self.prediction_network = PredictionNetwork( + observation_shape, + action_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + fc_value_layers, + fc_policy_layers, + self.value_support_size, + flatten_output_size_for_value_head, + flatten_output_size_for_policy_head, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + ) + + self.afterstate_dynamics_network = AfterstateDynamicsNetwork( + num_res_blocks, + num_channels + 1, + reward_head_channels, + fc_reward_layers, + self.reward_support_size, + flatten_output_size_for_reward_head, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + ) + self.afterstate_prediction_network = AfterstatePredictionNetwork( + chance_space_size, + num_res_blocks, + num_channels, + value_head_channels, + policy_head_channels, + fc_value_layers, + fc_policy_layers, + 
self.value_support_size, + flatten_output_size_for_value_head, + flatten_output_size_for_policy_head, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + ) + + if self.self_supervised_learning_loss: + # projection used in EfficientZero + if self.downsample: + # In Atari, if the observation_shape is set to (12, 96, 96), which indicates the original shape of + # (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is + # (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus, + # self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304 + ceil_size = math.ceil(observation_shape[1] / 16) * math.ceil(observation_shape[2] / 16) + self.projection_input_dim = num_channels * ceil_size + else: + self.projection_input_dim = num_channels * observation_shape[1] * observation_shape[2] + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Initial inference of MuZero model, which is the first step of the MuZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs (:obj:`torch.Tensor`): The 2D image observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + batch_size = obs.size(0) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor, afterstate: bool = False) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of MuZero model, which is the rollout step of the MuZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward``, by the given current ``latent_state`` and ``action``. 
+ We then use the prediction network to predict the ``value`` and ``policy_logits`` of the current + ``latent_state``. + Arguments: + - state (:obj:`torch.Tensor`): The encoding latent state of input state or the afterstate. + - option (:obj:`torch.Tensor`): The action to rollout or the chance to predict next latent state. + - afterstate (:obj:`bool`): Whether to use afterstate prediction network to predict next latent state. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + + if afterstate: + # state is afterstate, option is chance + next_latent_state, reward = self._dynamics(state, option) + policy_logits, value = self._prediction(next_latent_state) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + else: + # state is latent_state, option is action + next_afterstate, reward = self._afterstate_dynamics(state, option) + policy_logits, value = self._afterstate_prediction(next_afterstate) + return MZNetworkOutput(value, reward, policy_logits, next_afterstate) + + def _representation(self, observation: torch.Tensor) -> torch.Tensor: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 2D image observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _encode_vqvae(self, observation: torch.Tensor): + output = self.encoder(observation) + return output + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the prediction network to predict ``policy_logits`` and ``value``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. 
+ - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + return self.prediction_network(latent_state) + + def _afterstate_prediction(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the prediction network to predict ``policy_logits`` and ``value``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + return self.afterstate_prediction_network(afterstate) + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + and ``reward``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + # the final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1). + action_encoding = ( + torch.ones(( + latent_state.shape[0], + 1, + latent_state.shape[2], + latent_state.shape[3], + )).to(action.device).float() + ) + if len(action.shape) == 2: + # (batch_size, action_dim) -> (batch_size, action_dim, 1) + # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1]) + action = action.unsqueeze(-1) + elif len(action.shape) == 1: + # (batch_size,) -> (batch_size, action_dim=1, 1) + # e.g., -> torch.Size([8, 1]) -> torch.Size([8, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1) + + # action[:, 0, None, None] shape: (batch_size, action_dim, 1, 1) e.g. (8, 1, 1, 1) + # the final action_encoding shape: (batch_size, 1, latent_state[2], latent_state[3]) e.g. 
(8, 1, 4, 1), + # where each element is normalized as action[i]/action_space_size + action_encoding = (action[:, 0, None, None] * action_encoding / self.chance_space_size) + + # state_action_encoding shape: (batch_size, latent_state[1] + 1, latent_state[2], latent_state[3]) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + return next_latent_state, reward + + def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + and ``reward``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - reward (:obj:`torch.Tensor`): The predicted reward of the current latent state and selected action. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + # the final action_encoding shape is (batch_size, 1, latent_state[2], latent_state[3]), e.g. (8, 1, 4, 1). + action_encoding = ( + torch.ones(( + latent_state.shape[0], + 1, + latent_state.shape[2], + latent_state.shape[3], + )).to(action.device).float() + ) + if len(action.shape) == 2: + # (batch_size, action_dim) -> (batch_size, action_dim, 1) + # e.g., torch.Size([8, 1]) -> torch.Size([8, 1, 1]) + action = action.unsqueeze(-1) + elif len(action.shape) == 1: + # (batch_size,) -> (batch_size, action_dim=1, 1) + # e.g., -> torch.Size([8, 1]) -> torch.Size([8, 1, 1]) + action = action.unsqueeze(-1).unsqueeze(-1) + + # action[:, 0, None, None] shape: (batch_size, action_dim, 1, 1) e.g. (8, 1, 1, 1) + # the final action_encoding shape: (batch_size, 1, latent_state[2], latent_state[3]) e.g. (8, 1, 4, 1), + # where each element is normalized as action[i]/action_space_size + action_encoding = (action[:, 0, None, None] * action_encoding / self.action_space_size) + + # state_action_encoding shape: (batch_size, latent_state[1] + 1, latent_state[2], latent_state[3]) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.afterstate_dynamics_network(state_action_encoding) + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + return next_latent_state, reward + + + + def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is involved in + MuZero algorithm in EfficientZero. + For more details, please refer to paper ``Exploring Simple Siamese Representation Learning``. 
+ Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64, 6, 6) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + + .. note:: + for Atari: + observation_shape = (12, 96, 96), # original shape is (3,96,96), frame_stack_num=4 + if downsample is True, latent_state.shape: (batch_size, num_channel, obs_shape[1] / 16, obs_shape[2] / 16) + i.e., (256, 64, 96 / 16, 96 / 16) = (256, 64, 6, 6) + latent_state reshape: (256, 64, 6, 6) -> (256,64*6*6) = (256, 2304) + # self.projection_input_dim = 64*6*6 = 2304 + # self.projection_output_dim = 1024 + """ + latent_state = latent_state.reshape(latent_state.shape[0], -1) + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + num_res_blocks: int, + num_channels: int, + reward_head_channels: int, + fc_reward_layers: SequenceType, + output_support_size: int, + flatten_output_size_for_reward_head: int, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state and + reward given current latent state and action. + Arguments: + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of input, including obs and action encoding. + - reward_head_channels (:obj:`int`): The channels of reward head. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - flatten_output_size_for_reward_head (:obj:`int`): The flatten size of output for reward head, i.e., \ + the input size of reward head. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ + reward mlp, default set it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). 
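Examples (an illustrative shape sketch with hypothetical sizes: a 6x6 latent map, 64 latent channels plus one chance-encoding plane):
    >>> net = DynamicsNetwork(num_res_blocks=1, num_channels=64 + 1, reward_head_channels=16,
    ...                       fc_reward_layers=[32], output_support_size=601,
    ...                       flatten_output_size_for_reward_head=16 * 6 * 6)
    >>> state_chance = torch.randn(8, 65, 6, 6)
    >>> next_latent, reward = net(state_chance)   # shapes: (8, 64, 6, 6) and (8, 601)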
+ """ + super().__init__() + self.num_channels = num_channels + self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head + + self.conv = nn.Conv2d(num_channels, num_channels - 1, kernel_size=3, stride=1, padding=1, bias=False) + self.bn = nn.BatchNorm2d(num_channels - 1) + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels - 1, activation=activation, norm_type='BN', res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_reward = nn.Conv2d(num_channels - 1, reward_head_channels, 1) + self.bn_reward = nn.BatchNorm2d(reward_head_channels) + self.fc_reward_head = MLP( + self.flatten_output_size_for_reward_head, + hidden_channels=fc_reward_layers[0], + layer_num=len(fc_reward_layers) + 1, + out_channels=output_support_size, + activation=activation, + norm_type='BN', + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.activation = activation + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, num_channels, \ + height, width). + - reward (:obj:`torch.Tensor`): The predicted reward, with shape (batch_size, output_support_size). + """ + # take the state encoding (latent_state), state_action_encoding[:, -1, :, :] is action encoding + latent_state = state_action_encoding[:, :-1, :, :] + x = self.conv(state_action_encoding) + x = self.bn(x) + + # the residual link: add state encoding to the state_action encoding + x += latent_state + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + x = self.conv1x1_reward(next_latent_state) + x = self.bn_reward(x) + x = self.activation(x) + x = x.view(-1, self.flatten_output_size_for_reward_head) + + # use the fully connected layer to predict reward + reward = self.fc_reward_head(x) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> float: + return get_reward_mean(self) + +class AfterstateDynamicsNetwork(nn.Module): + + def __init__( + self, + num_res_blocks: int, + num_channels: int, + reward_head_channels: int, + fc_reward_layers: SequenceType, + output_support_size: int, + flatten_output_size_for_reward_head: int, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state and + reward given current latent state and action. + Arguments: + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of input, including obs and action encoding. + - reward_head_channels (:obj:`int`): The channels of reward head. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. 
+ - flatten_output_size_for_reward_head (:obj:`int`): The flatten size of output for reward head, i.e., \ + the input size of reward head. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ + reward mlp, default set it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + """ + super().__init__() + self.num_channels = num_channels + self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head + self.conv = nn.Conv2d(num_channels, num_channels - 1, kernel_size=3, stride=1, padding=1, bias=False) + self.bn = nn.BatchNorm2d(num_channels - 1) + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels - 1, activation=activation, norm_type='BN', res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.conv1x1_reward = nn.Conv2d(num_channels - 1, reward_head_channels, 1) + self.bn_reward = nn.BatchNorm2d(reward_head_channels) + self.fc_reward_head = MLP( + self.flatten_output_size_for_reward_head, + hidden_channels=fc_reward_layers[0], + layer_num=len(fc_reward_layers) + 1, + out_channels=output_support_size, + activation=activation, + norm_type='BN', + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.activation = activation + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, num_channels, \ + height, width). + - reward (:obj:`torch.Tensor`): The predicted reward, with shape (batch_size, output_support_size). 
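Note (illustrative, following the Stochastic MuZero formulation): this network realizes the first half of the
factored transition s_t --a_t--> as_t --c_t--> s_{t+1}. The afterstate ``as_t`` returned here is scored by
``AfterstatePredictionNetwork`` (chance distribution sigma and Q-value), and ``DynamicsNetwork`` then maps
``as_t`` together with the chance outcome ``c_t`` to the next latent state, so that roughly
Q(s_t, a_t) = sum_c sigma(c | as_t) * (r + gamma * v(s_{t+1})).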
+ """ + # take the state encoding (latent_state), state_action_encoding[:, -1, :, :] is action encoding + latent_state = state_action_encoding[:, :-1, :, :] + x = self.conv(state_action_encoding) + x = self.bn(x) + + # the residual link: add state encoding to the state_action encoding + x += latent_state + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + afterstate_latent_state = x + # reward = None + + x = self.conv1x1_reward(afterstate_latent_state) + x = self.bn_reward(x) + x = self.activation(x) + x = x.view(-1, self.flatten_output_size_for_reward_head) + + # use the fully connected layer to predict reward + reward = self.fc_reward_head(x) + + return afterstate_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> float: + return get_reward_mean(self) + + +class AfterstatePredictionNetwork(nn.Module): + def __init__( + self, + action_space_size: int, + num_res_blocks: int, + num_channels: int, + value_head_channels: int, + policy_head_channels: int, + fc_value_layers: int, + fc_policy_layers: int, + output_support_size: int, + flatten_output_size_for_value_head: int, + flatten_output_size_for_policy_head: int, + last_linear_layer_init_zero: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + ) -> None: + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + Arguments: + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the value head. + - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the policy head. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ + dynamics/prediction mlp, default set it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). 
+ """ + super(AfterstatePredictionNetwork, self).__init__() + self.resblocks = nn.ModuleList( + [ + ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(num_res_blocks) + ] + ) + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) + self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) + self.bn_value = nn.BatchNorm2d(value_head_channels) + self.bn_policy = nn.BatchNorm2d(policy_head_channels) + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head + self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head + self.activation = activation + + self.fc_value = MLP( + in_channels=self.flatten_output_size_for_value_head, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=self.activation, + norm_type='BN', + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy = MLP( + in_channels=self.flatten_output_size_for_policy_head, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=self.activation, + norm_type='BN', + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + for res_block in self.resblocks: + latent_state = res_block(latent_state) + + value = self.conv1x1_value(latent_state) + value = self.bn_value(value) + value = self.activation(value) + + policy = self.conv1x1_policy(latent_state) + policy = self.bn_policy(policy) + policy = self.activation(policy) + + value = value.reshape(-1, self.flatten_output_size_for_value_head) + policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) + + value = self.fc_value(value) + policy = self.fc_policy(policy) + return policy, value + +class ImgNet(nn.Module): + def __init__(self, observation_space_dimensions, table_vec_dim=4): + super(ImgNet, self).__init__() + self.conv1 = nn.Conv2d(observation_space_dimensions[0]*2, 32, 3, padding=1) + self.conv2 = nn.Conv2d(32, 64, 3, padding=1) + self.fc1 = nn.Linear(64 * observation_space_dimensions[1] * observation_space_dimensions[2], 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, table_vec_dim) + + def forward(self, x): + x = torch.relu(self.conv1(x)) + x = torch.relu(self.conv2(x)) + # x = x.view(-1, 64 * 4 * 4) + x = x.view(x.shape[0], -1) + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class Encoder_function(nn.Module): + def __init__(self, + observation_space_dimensions, + action_dimension): + super().__init__() + self.action_space = action_dimension + self.encoder = ImgNet(observation_space_dimensions, action_dimension) + # # # # add to sequence|first and recursive|,, whatever you need + # linear_in = nn.Linear(observation_space_dimensions, hidden_layer_dimensions) + # linear_mid = nn.Linear(hidden_layer_dimensions, hidden_layer_dimensions) + # linear_out = nn.Linear(hidden_layer_dimensions, state_dimension) + + # self.scale = nn.Tanh() + # layernom_init = nn.BatchNorm1d(observation_space_dimensions) + # layernorm_recur = nn.BatchNorm1d(hidden_layer_dimensions) + # # 0.1, 0.2 , 0.25 , 0.5 parameter (first two more recommended for rl) + # dropout = nn.Dropout(0.1) + # activation = nn.ELU() # , nn.ELU() , nn.GELU, nn.ELU() , nn.ELU + + # first_layer_sequence = [ + # linear_in, + # activation + # ] + + # recursive_layer_sequence = [ + # linear_mid, + # activation + # ] + + # sequence = first_layer_sequence + \ + # (recursive_layer_sequence*number_of_hidden_layer) + + # self.encoder = nn.Sequential(*tuple(sequence+[nn.Linear(hidden_layer_dimensions, action_dimension)])) + self.onehot_argmax = StraightThroughEstimator() + def forward(self, o_i): + #https://openreview.net/pdf?id=X6D9bAHhBQ1 [page:5 chance outcome] + c_e_t = torch.nn.Softmax(-1)(self.encoder(o_i)) + #c_t= torch.zeros_like(c_e_t).scatter_(-1, torch.argmax(c_e_t, dim=-1,keepdim=True), 1.) + c_t = self.onehot_argmax(c_e_t) + return c_t,c_e_t + + + +class StraightThroughEstimator(nn.Module): + def __init__(self): + super(StraightThroughEstimator, self).__init__() + + def forward(self, x): + x = Onehot_argmax.apply(x) + return x +#straight-through estimator is used during the backward to allow the gradients to flow only to the encoder during the backpropagation. +class Onehot_argmax(torch.autograd.Function): + #more information at : https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html + @staticmethod + def forward(ctx, input): + #since the codebook is constant ,we can just use a transformation. no need to create a codebook and matmul c_e_t and codebook for argmax + return torch.zeros_like(input).scatter_(-1, torch.argmax(input, dim=-1,keepdim=True), 1.) 
+ + @staticmethod + def backward(ctx, grad_output): + return grad_output \ No newline at end of file diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py new file mode 100644 index 000000000..401095952 --- /dev/null +++ b/lzero/policy/stochastic_muzero.py @@ -0,0 +1,738 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.policy.base_policy import Policy +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.nn import L1Loss + +# from lzero.mcts import StochasticMuZeroMCTSCtree as MCTSCtree +from lzero.mcts import StochasticMuZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers + + +@POLICY_REGISTRY.register('stochastic_muzero') +class StochasticMuZeroPolicy(Policy): + """ + Overview: + The policy class for MuZero. + """ + + # The default_config for MuZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (int) The chance space size. + chance_space_size=2, + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=False, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=300, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + ), + # ****** common ****** + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options is ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of battle mode. Options is ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. 
+ game_segment_length=200, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor + update_per_collect=100, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + # (float) Learning rate for training policy network. Ininitial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episode in each collecting stage. + n_episode=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of step for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=True, + # (bool) Whether to use the maximum priority for new collecting data. + use_max_priority_for_new_data=True, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. 
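        # Illustrative note (standard MuZero convention, assumed rather than taken from this patch): the root prior is
        # mixed with noise as P'(a) = (1 - root_noise_weight) * P(a) + root_noise_weight * n_a, with n ~ Dir(root_dirichlet_alpha);
        # a larger alpha spreads the exploration noise more evenly across actions.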
+ root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.muzero_model.MuZeroModel`` + """ + if self._cfg.model.model_type == "conv": + return 'StochasticMuZeroModel', ['lzero.model.stochastic_muzero_model'] + elif self._cfg.model.model_type == "mlp": + return 'StochasticMuZeroModelMLP', ['lzero.model.stochastic_muzero_model_mlp'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Ininitialize the learn model, optimizer and MCTS utils. + """ + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning policy in learn mode, which is the core of the learning process. + The data is sampled from replay buffer. + The loss is calculated by the loss function and the loss is backpropagated to update the model. 
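            For orientation, the two elements of ``data`` are unpacked in the body of this method as follows (names as used below):

                current_batch = (obs_batch_ori, action_batch, mask_batch, indices, weights, make_time)
                target_batch  = (target_reward, target_value, target_policy)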
+ Arguments: + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + The first tensor is the current_batch, the second tensor is the target_batch. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ + current learning loss and learning statistics. + """ + self._learn_model.train() + self._target_model.train() + + current_batch, target_batch = data + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + encoder_image_list = [] + encoder_image_list.append(obs_batch) + for zt in range(5): + beg_index = self._cfg.model.image_channel * zt + end_index = self._cfg.model.image_channel * (zt + self._cfg.model.frame_stack_num) + encoder_image_list.append(obs_target_batch[:, beg_index:end_index, :, :]) + + # do augmentations + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # shape: (batch_size, num_unroll_steps, action_dim) + # NOTE: .long(), in discrete action space. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_reward.astype('float64'), + target_value.astype('float64'), target_policy, weights + ] + [mask_batch, target_reward, target_value, target_policy, + weights] = to_torch_float_tensor(data_list, self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + + assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in MuZero policy. + # ============================================================== + network_output = self._learn_model.initial_inference(obs_batch) + + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + # Note: The following lines are just for debugging. + predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + latent_state_list = latent_state.detach().cpu().numpy() + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # calculate the new priorities for each transition. 
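            # Illustrative note: the priority computed just below is the first-step value error |v_theta(s_0) - z_0| + 1e-6;
            # under standard prioritized replay it is later raised to ``priority_prob_alpha`` for sampling and corrected with
            # importance weights controlled by ``priority_prob_beta`` (see the config above).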
+ value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # ============================================================== + # calculate policy and value loss for the first step. + # ============================================================== + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + + afterstate_policy_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + afterstate_value_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + commitment_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + + gradient_scale = 1 / self._cfg.num_unroll_steps + + # ============================================================== + # the core recurrent_inference in MuZero policy. + # ============================================================== + for step_i in range(self._cfg.num_unroll_steps): + # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, + # given current ``latent_state`` and ``action``. + # And then predict policy_logits and value with the prediction function. + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_i], afterstate=False) + after_state, a_reward, a_value, a_policy_logits = mz_network_output_unpack(network_output) + + former_frame = encoder_image_list[step_i] + latter_frame = encoder_image_list[step_i+1] + concat_frame = torch.cat((former_frame, latter_frame), dim=1) + + chance_code, encode_output = self._learn_model._encode_vqvae(concat_frame) + #chance_code, encode_output = self._learn_model._encode_vqvae(encoder_image_list[step_i]) + chance_code_long = torch.argmax(chance_code, dim=1).long().unsqueeze(-1) + + network_output = self._learn_model.recurrent_inference(after_state, chance_code_long, afterstate=True) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.inverse_scalar_transform_handle(value) + + if self._cfg.model.self_supervised_learning_loss: + # ============================================================== + # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. + # ============================================================== + if self._cfg.ssl_loss_weight > 0: + # obtain the oracle hidden states from representation function. 
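The unroll loop above is the heart of Stochastic MuZero training: each step first applies the afterstate dynamics to the agent's action, then infers the realized chance outcome from two consecutive frame stacks with the VQ-VAE-style encoder, and finally applies the chance-conditioned dynamics. The sketch below restates that control flow in isolation. It assumes the same model interface used above (``recurrent_inference`` with an ``afterstate`` flag, ``_encode_vqvae``) and that ``mz_network_output_unpack`` is importable from ``lzero.policy``; it is a paraphrase of the loop, not a drop-in helper.

import torch
from lzero.policy import mz_network_output_unpack  # assumed import path, mirroring the usage above

def unroll_one_step_sketch(model, latent_state, action, obs_t, obs_tp1):
    # 1) Afterstate dynamics: the deterministic effect of the agent's action.
    out = model.recurrent_inference(latent_state, action, afterstate=False)
    after_state, _, afterstate_value, chance_prior_logits = mz_network_output_unpack(out)
    # 2) Infer the realized chance outcome from two consecutive frame stacks
    #    with the VQ-VAE-style encoder, as in the training loop above.
    chance_one_hot, encoder_logits = model._encode_vqvae(torch.cat((obs_t, obs_tp1), dim=1))
    chance_code = torch.argmax(chance_one_hot, dim=1).long().unsqueeze(-1)
    # 3) Chance-conditioned dynamics: the stochastic transition to the next latent state.
    out = model.recurrent_inference(after_state, chance_code, afterstate=True)
    next_latent_state, reward, value, policy_logits = mz_network_output_unpack(out)
    return next_latent_state, reward, value, policy_logits

The hunk below resumes with the self-supervised consistency target, i.e. the oracle latent state produced by the representation network.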
+ if self._cfg.model.model_type == 'conv': + beg_index = self._cfg.model.image_channel * step_i + end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num) + network_output = self._learn_model.initial_inference( + obs_target_batch[:, beg_index:end_index, :, :] + ) + elif self._cfg.model.model_type == 'mlp': + beg_index = self._cfg.model.observation_shape * step_i + end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + consistency_loss += temp_loss + + # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in + # game buffer now. + # ============================================================== + # calculate policy loss for the next ``num_unroll_steps`` unroll steps. + # NOTE: the +=. + # ============================================================== + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) + afterstate_policy_loss += cross_entropy_loss(a_policy_logits, chance_code) + commitment_loss += cross_entropy_loss(encode_output, chance_code) + + afterstate_value_loss += cross_entropy_loss(a_value, target_value_categorical[:, step_i]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) + + # Follow MuZero, set half gradient + # latent_state.register_hook(lambda grad: grad * 0.5) + + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu())) + latent_state_list = np.concatenate((latent_state_list, latent_state.detach().cpu().numpy())) + + # ============================================================== + # the core learn model update step. + # ============================================================== + # weighted loss with masks (some invalid states which are out of trajectory.) 
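The consistency term above follows the SimSiam/EfficientZero recipe: project the unrolled latent state with gradients, project the representation-network latent state without gradients, and penalize their negative cosine similarity. Below is a minimal re-implementation of that per-sample loss for illustration; in the code above the stop-gradient is realized through ``project(..., with_grad=False)``, while here it is made explicit with ``detach()``.

import torch
import torch.nn.functional as F

def negative_cosine_similarity_sketch(dynamic_proj: torch.Tensor,
                                      observation_proj: torch.Tensor) -> torch.Tensor:
    # The observation branch is treated as a fixed target (no gradient flows back).
    observation_proj = observation_proj.detach()
    dynamic_proj = F.normalize(dynamic_proj, p=2, dim=-1)
    observation_proj = F.normalize(observation_proj, p=2, dim=-1)
    # Per-sample loss in [-1, 1]; the caller masks it per unroll step and accumulates it.
    return -(dynamic_proj * observation_proj).sum(dim=-1)

The weighted sum below then combines this term with the policy, value, reward, afterstate and commitment losses.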
+ loss = ( + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + + self._cfg.policy_loss_weight * afterstate_policy_loss + self._cfg.value_loss_weight * afterstate_value_loss + + self._cfg.policy_loss_weight * commitment_loss + + ) + weighted_total_loss = (weights * loss).mean() + + gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), self._cfg.grad_clip_value + ) + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay is True: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. + # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + + # packing loss info for tensorboard logging + loss_info = ( + weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), reward_loss.mean().item(), + value_loss.mean().item(), consistency_loss.mean(), afterstate_policy_loss.mean().item(), + afterstate_value_loss.mean().item(), commitment_loss.mean().item(), + + ) + if self._cfg.monitor_extra_statistics: + predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1) + predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) + + td_data = ( + value_priority, + target_reward.detach().cpu().numpy(), + target_value.detach().cpu().numpy(), + transformed_target_reward.detach().cpu().numpy(), + transformed_target_value.detach().cpu().numpy(), + target_reward_categorical.detach().cpu().numpy(), + target_value_categorical.detach().cpu().numpy(), + predicted_rewards.detach().cpu().numpy(), + predicted_values.detach().cpu().numpy(), + target_policy.detach().cpu().numpy(), + predicted_policies.detach().cpu().numpy(), + latent_state_list, + ) + + return { + 'collect_mcts_temperature': self.collect_mcts_temperature, + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': loss_info[0], + 'total_loss': loss_info[1], + 'policy_loss': loss_info[2], + 'reward_loss': loss_info[3], + 'value_loss': loss_info[4], + 'consistency_loss': loss_info[5] / self._cfg.num_unroll_steps, + 'afterstate_policy_loss': loss_info[6], + 'afterstate_value_loss': loss_info[7], + 'commitment_loss': loss_info[8], + + # ============================================================== + # priority related + # ============================================================== + 'value_priority_orig': value_priority, + 'value_priority': td_data[0].flatten().mean().item(), + 'target_reward': td_data[1].flatten().mean().item(), + 'target_value': td_data[2].flatten().mean().item(), + 'transformed_target_reward': td_data[3].flatten().mean().item(), + 'transformed_target_value': td_data[4].flatten().mean().item(), + 'predicted_rewards': td_data[7].flatten().mean().item(), + 'predicted_values': td_data[8].flatten().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip + } + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Ininitialize the collect model and MCTS utils. 
+ """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self.collect_mcts_temperature = 1 + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + ready_env_id=None + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self.collect_mcts_temperature = temperature + active_collect_env_num = data.shape[0] + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} + network_output = self._collect_model.initial_inference(data) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._learn_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) + + roots_visit_count_distributions = roots.get_distributions( + ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_values = roots.get_values() # shape: {list: batch_size} + + data_id = [i for i in range(active_collect_env_num)] + output = {i: None for i in data_id} + + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self.collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'value': value, + 'pred_value': pred_values[i], + 'policy_logits': policy_logits[i], + } + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Ininitialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. 
+        Shape:
+            - data (:obj:`torch.Tensor`):
+                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of eval_env, C is the number of channels, \
+                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
+                - For lunarlander, :math:`(N, O)`, where N is the number of eval_env, O is the observation space size.
+            - action_mask: :math:`(N, action_space_size)`, where N is the number of eval_env.
+            - to_play: :math:`(N, 1)`, where N is the number of eval_env.
+            - ready_env_id: None
+        Returns:
+            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
+                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
+        """
+        self._eval_model.eval()
+        active_eval_env_num = data.shape[0]
+        with torch.no_grad():
+            # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
+            network_output = self._eval_model.initial_inference(data)
+            latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
+
+            if not self._eval_model.training:
+                # if not in training, obtain the scalars of the value/reward
+                pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
+                latent_state_roots = latent_state_roots.detach().cpu().numpy()
+                policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
+            if self._cfg.mcts_ctree:
+                # cpp mcts_tree
+                roots = MCTSCtree.roots(active_eval_env_num, legal_actions)
+            else:
+                # python mcts_tree
+                roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
+            roots.prepare_no_noise(reward_roots, policy_logits, to_play)
+            self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play)
+
+            roots_visit_count_distributions = roots.get_distributions(
+            )  # shape: ``{list: batch_size} ->{list: action_space_size}``
+            roots_values = roots.get_values()  # shape: {list: batch_size}
+
+            data_id = [i for i in range(active_eval_env_num)]
+            output = {i: None for i in data_id}
+
+            if ready_env_id is None:
+                ready_env_id = np.arange(active_eval_env_num)
+
+            for i, env_id in enumerate(ready_env_id):
+                distributions, value = roots_visit_count_distributions[i], roots_values[i]
+                # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
+                # the index within the legal action set, rather than the index in the entire action set.
+                # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than
+                # sampling during the evaluation phase.
+                action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
+                    distributions, temperature=1, deterministic=True
+                )
+                # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the
+                # entire action set.
+                action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+
+                output[env_id] = {
+                    'action': action,
+                    'distributions': distributions,
+                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                    'value': value,
+                    'pred_value': pred_values[i],
+                    'policy_logits': policy_logits[i],
+                }
+
+        return output
+
+    def _monitor_vars_learn(self) -> List[str]:
+        """
+        Overview:
+            Register the variables to be monitored in learn mode. The registered variables will be logged in
+            tensorboard according to the return value of ``_forward_learn``.
+ """ + return [ + 'collect_mcts_temperature', + 'cur_lr', + 'weighted_total_loss', + 'total_loss', + 'policy_loss', + 'reward_loss', + 'value_loss', + 'consistency_loss', + 'afterstate_policy_loss', + 'commitment_loss', + 'afterstate_value_loss', + 'value_priority', + 'target_reward', + 'target_value', + 'predicted_rewards', + 'predicted_values', + 'transformed_target_reward', + 'transformed_target_value', + 'total_grad_norm_before_clip', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model, target_model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/zoo/atari/config/atari_stochastic_muzero.py b/zoo/atari/config/atari_stochastic_muzero.py new file mode 100644 index 000000000..1fbefb784 --- /dev/null +++ b/zoo/atari/config/atari_stochastic_muzero.py @@ -0,0 +1,99 @@ +from easydict import EasyDict + +# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} +env_name = 'PongNoFrameskip-v4' + +if env_name == 'PongNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'QbertNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'MsPacmanNoFrameskip-v4': + action_space_size = 9 +elif env_name == 'SpaceInvadersNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'BreakoutNoFrameskip-v4': + action_space_size = 4 + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 1 +n_episode = 1 +evaluator_env_num = 1 +num_simulations = 5 +update_per_collect = 10 +batch_size = 25 +max_env_step = int(1e6) +reanalyze_ratio = 0. 
+chance_space_size = 2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_stochastic_muzero_config = dict( + exp_name= + f'data_stochastic_mz_ctree/{env_name[:-14]}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_chance{chance_space_size}_seed0', + env=dict( + stop_value=int(1e6), + env_name=env_name, + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=action_space_size, + chance_space_size=chance_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + mcts_ctree=False, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +atari_stochastic_muzero_config = EasyDict(atari_stochastic_muzero_config) +main_config = atari_stochastic_muzero_config + +atari_stochastic_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='stochastic_muzero', + import_names=['lzero.policy.stochastic_muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +atari_stochastic_muzero_create_config = EasyDict(atari_stochastic_muzero_create_config) +create_config = atari_stochastic_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) From 3799b46d8c768abebb1d9b59d383534b1ab3cea1 Mon Sep 17 00:00:00 2001 From: timothijoe Date: Sat, 10 Jun 2023 16:59:29 +0800 Subject: [PATCH 02/28] add stochastic mz ctree --- .../buffer/game_buffer_stochastic_muzero.py | 2 +- .../ctree/ctree_stochastic_muzero/__init__.py | 0 .../ctree_stochastic_muzero/lib/cnode.cpp | 794 ++++++++++++++++++ .../ctree/ctree_stochastic_muzero/lib/cnode.h | 95 +++ .../stochastic_mz_tree.pxd | 74 ++ .../stochastic_mz_tree.pyx | 91 ++ lzero/mcts/tree_search/__init__.py | 1 + .../mcts/tree_search/mcts_ctree_stochastic.py | 237 ++++++ lzero/policy/stochastic_muzero.py | 2 +- zoo/atari/config/atari_stochastic_muzero.py | 3 +- 10 files changed, 1296 insertions(+), 3 deletions(-) create mode 100644 lzero/mcts/ctree/ctree_stochastic_muzero/__init__.py create mode 100644 lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp create mode 100644 lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.h create mode 100644 lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pxd create mode 100644 lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pyx create mode 100644 lzero/mcts/tree_search/mcts_ctree_stochastic.py diff --git 
a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py index 6dfe6d8e2..28c978a3a 100644 --- a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py +++ b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py @@ -4,7 +4,7 @@ import torch from ding.utils import BUFFER_REGISTRY -from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree +from lzero.mcts.tree_search.mcts_ctree_stochastic import StochasticMuZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree_stochastic import StochasticMuZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform diff --git a/lzero/mcts/ctree/ctree_stochastic_muzero/__init__.py b/lzero/mcts/ctree/ctree_stochastic_muzero/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp new file mode 100644 index 000000000..cf3c8d1e2 --- /dev/null +++ b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp @@ -0,0 +1,794 @@ +// C++11 + +#include +#include "cnode.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include "..\..\common_lib\utils.cpp" +#else +#include "../../common_lib/utils.cpp" +#endif + + +namespace tree +{ + + CSearchResults::CSearchResults() + { + /* + Overview: + Initialization of CSearchResults, the default result number is set to 0. + */ + this->num = 0; + } + + CSearchResults::CSearchResults(int num) + { + /* + Overview: + Initialization of CSearchResults with result number. + */ + this->num = num; + for (int i = 0; i < num; ++i) + { + this->search_paths.push_back(std::vector()); + } + } + + CSearchResults::~CSearchResults() {} + + //********************************************************* + + CNode::CNode() + { + /* + Overview: + Initialization of CNode. + */ + this->prior = 0; + this->legal_actions = legal_actions; + + this->visit_count = 0; + this->value_sum = 0; + this->best_action = -1; + this->to_play = 0; + this->reward = 0.0; + this->is_chance = false; + this->chance_space_size= 2; + + } + + CNode::CNode(float prior, std::vector &legal_actions, bool is_chance, int chance_space_size) + { + /* + Overview: + Initialization of CNode with prior value and legal actions. + Arguments: + - prior: the prior value of this node. + - legal_actions: a vector of legal actions of this node. + */ + this->prior = prior; + this->legal_actions = legal_actions; + + this->visit_count = 0; + this->value_sum = 0; + this->best_action = -1; + this->to_play = 0; + this->current_latent_state_index = -1; + this->batch_index = -1; + this->is_chance = is_chance; + this->chance_space_size = chance_space_size; + } + + CNode::~CNode() {} + + void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector &policy_logits, bool child_is_chance) + { + /* + Overview: + Expand the child nodes of the current node. + Arguments: + - to_play: which player to play the game in the current node. + - current_latent_state_index: The index of latent state of the leaf node in the search path of the current node. + - batch_index: The index of latent state of the leaf node in the search path of the current node. + - reward: the reward of the current node. + - policy_logits: the logit of the child nodes. 
+ */ + this->to_play = to_play; + this->current_latent_state_index = current_latent_state_index; + this->batch_index = batch_index; + this->reward = reward; + + + // assert((this->is_chance != child_is_chance) && "is_chance and child_is_chance should be different"); + + if(this->is_chance == true){ + child_is_chance = false; + this->reward = 0.0; + } + else{ + child_is_chance = true; + } + + int action_num = policy_logits.size(); + if (this->legal_actions.size() == 0) + { + for (int i = 0; i < action_num; ++i) + { + this->legal_actions.push_back(i); + } + } + + float temp_policy; + float policy_sum = 0.0; + + #ifdef _WIN32 + // 创建动态数组 + float* policy = new float[action_num]; + #else + float policy[action_num]; + #endif + + float policy_max = FLOAT_MIN; + for (auto a : this->legal_actions) + { + if (policy_max < policy_logits[a]) + { + policy_max = policy_logits[a]; + } + } + + for (auto a : this->legal_actions) + { + temp_policy = exp(policy_logits[a] - policy_max); + policy_sum += temp_policy; + policy[a] = temp_policy; + } + + float prior; + for (auto a : this->legal_actions) + { + prior = policy[a] / policy_sum; + std::vector tmp_empty; + this->children[a] = CNode(prior, tmp_empty, child_is_chance, this->chance_space_size); // only for muzero/efficient zero, not support alphazero + // this->children[a] = CNode(prior, tmp_empty, is_chance = child_is_chance); // only for muzero/efficient zero, not support alphazero + } + + #ifdef _WIN32 + // 释放数组内存 + delete[] policy; + #else + #endif + } + + void CNode::add_exploration_noise(float exploration_fraction, const std::vector &noises) + { + /* + Overview: + Add a noise to the prior of the child nodes. + Arguments: + - exploration_fraction: the fraction to add noise. + - noises: the vector of noises added to each child node. + */ + float noise, prior; + for (int i = 0; i < this->legal_actions.size(); ++i) + { + noise = noises[i]; + CNode *child = this->get_child(this->legal_actions[i]); + + prior = child->prior; + child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction; + } + } + + float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor) + { + /* + Overview: + Compute the mean q value of the current node. + Arguments: + - isRoot: whether the current node is a root node. + - parent_q: the q value of the parent node. + - discount_factor: the discount_factor of reward. + */ + float total_unsigned_q = 0.0; + int total_visits = 0; + for (auto a : this->legal_actions) + { + CNode *child = this->get_child(a); + if (child->visit_count > 0) + { + float true_reward = child->reward; + float qsa = true_reward + discount_factor * child->value(); + total_unsigned_q += qsa; + total_visits += 1; + } + } + + float mean_q = 0.0; + if (isRoot && total_visits > 0) + { + mean_q = (total_unsigned_q) / (total_visits); + } + else + { + mean_q = (parent_q + total_unsigned_q) / (total_visits + 1); + } + return mean_q; + } + + void CNode::print_out() + { + return; + } + + int CNode::expanded() + { + /* + Overview: + Return whether the current node is expanded. + */ + return this->children.size() > 0; + } + + float CNode::value() + { + /* + Overview: + Return the real value of the current tree. + */ + float true_value = 0.0; + if (this->visit_count == 0) + { + return true_value; + } + else + { + true_value = this->value_sum / this->visit_count; + return true_value; + } + } + + std::vector CNode::get_trajectory() + { + /* + Overview: + Find the current best trajectory starts from the current node. 
+ Outputs: + - traj: a vector of node index, which is the current best trajectory from this node. + */ + std::vector traj; + + CNode *node = this; + int best_action = node->best_action; + while (best_action >= 0) + { + traj.push_back(best_action); + + node = node->get_child(best_action); + best_action = node->best_action; + } + return traj; + } + + std::vector CNode::get_children_distribution() + { + /* + Overview: + Get the distribution of child nodes in the format of visit_count. + Outputs: + - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + */ + std::vector distribution; + if (this->expanded()) + { + for (auto a : this->legal_actions) + { + CNode *child = this->get_child(a); + distribution.push_back(child->visit_count); + } + } + return distribution; + } + + CNode *CNode::get_child(int action) + { + /* + Overview: + Get the child node corresponding to the input action. + Arguments: + - action: the action to get child. + */ + return &(this->children[action]); + } + + //********************************************************* + + CRoots::CRoots() + { + /* + Overview: + The initialization of CRoots. + */ + this->root_num = 0; + } + + CRoots::CRoots(int root_num, std::vector > &legal_actions_list, int chance_space_size=2) + { + /* + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num: the number of the current root. + - legal_action_list: the vector of the legal action of this root. + */ + this->root_num = root_num; + this->legal_actions_list = legal_actions_list; + + for (int i = 0; i < root_num; ++i) + { + this->roots.push_back(CNode(0, this->legal_actions_list[i], false, chance_space_size)); + // this->roots.push_back(CNode(0, this->legal_actions_list[i], false)); + + } + } + + CRoots::~CRoots() {} + + void CRoots::prepare(float root_noise_weight, const std::vector > &noises, const std::vector &rewards, const std::vector > &policies, std::vector &to_play_batch) + { + /* + Overview: + Expand the roots and add noises. + Arguments: + - root_noise_weight: the exploration fraction of roots + - noises: the vector of noise add to the roots. + - rewards: the vector of rewards of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. + */ + for (int i = 0; i < this->root_num; ++i) + { + this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i], true); + this->roots[i].add_exploration_noise(root_noise_weight, noises[i]); + + this->roots[i].visit_count += 1; + } + } + + void CRoots::prepare_no_noise(const std::vector &rewards, const std::vector > &policies, std::vector &to_play_batch) + { + /* + Overview: + Expand the roots without noise. + Arguments: + - rewards: the vector of rewards of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. + */ + for (int i = 0; i < this->root_num; ++i) + { + this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i], true); + + this->roots[i].visit_count += 1; + } + } + + void CRoots::clear() + { + /* + Overview: + Clear the roots vector. + */ + this->roots.clear(); + } + + std::vector > CRoots::get_trajectories() + { + /* + Overview: + Find the current best trajectory starts from each root. + Outputs: + - traj: a vector of node index, which is the current best trajectory from each root. 
+ */ + std::vector > trajs; + trajs.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + trajs.push_back(this->roots[i].get_trajectory()); + } + return trajs; + } + + std::vector > CRoots::get_distributions() + { + /* + Overview: + Get the children distribution of each root. + Outputs: + - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + */ + std::vector > distributions; + distributions.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + distributions.push_back(this->roots[i].get_children_distribution()); + } + return distributions; + } + + std::vector CRoots::get_values() + { + /* + Overview: + Return the real value of each root. + */ + std::vector values; + for (int i = 0; i < this->root_num; ++i) + { + values.push_back(this->roots[i].value()); + } + return values; + } + + //********************************************************* + // + void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players) + { + /* + Overview: + Update the q value of the root and its child nodes. + Arguments: + - root: the root that update q value from. + - min_max_stats: a tool used to min-max normalize the q value. + - discount_factor: the discount factor of reward. + - players: the number of players. + */ + std::stack node_stack; + node_stack.push(root); + while (node_stack.size() > 0) + { + CNode *node = node_stack.top(); + node_stack.pop(); + + if (node != root) + { + // # NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node, + // # but treated as 1 player, just for obtaining the true reward in the perspective of current player of node. + // # true_reward = node.value_prefix - (- parent_value_prefix) + // float true_reward = node->value_prefix - node->parent_value_prefix; + float true_reward = node->reward; + + float qsa; + if (players == 1) + qsa = true_reward + discount_factor * node->value(); + else if (players == 2) + // TODO(pu): + qsa = true_reward + discount_factor * (-1) * node->value(); + + min_max_stats.update(qsa); + } + + for (auto a : node->legal_actions) + { + CNode *child = node->get_child(a); + if (child->expanded()) + { + node_stack.push(child); + } + } + } + } + + void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor) + { + /* + Overview: + Update the value sum and visit count of nodes along the search path. + Arguments: + - search_path: a vector of nodes on the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - to_play: which player to play the game in the current node. + - value: the value to propagate along the search path. + - discount_factor: the discount factor of reward. 
+ */ + assert(to_play == -1 || to_play == 1 || to_play == 2); + if (to_play == -1) + { + // for play-with-bot-mode + float bootstrap_value = value; + int path_len = search_path.size(); + for (int i = path_len - 1; i >= 0; --i) + { + CNode *node = search_path[i]; + node->value_sum += bootstrap_value; + node->visit_count += 1; + + float true_reward = node->reward; + + min_max_stats.update(true_reward + discount_factor * node->value()); + + bootstrap_value = true_reward + discount_factor * bootstrap_value; + // std::cout << "to_play: " << to_play << std::endl; + + } + } + else + { + // for self-play-mode + float bootstrap_value = value; + int path_len = search_path.size(); + for (int i = path_len - 1; i >= 0; --i) + { + CNode *node = search_path[i]; + if (node->to_play == to_play) + node->value_sum += bootstrap_value; + else + node->value_sum += -bootstrap_value; + node->visit_count += 1; + + // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node, + // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node. + // float true_reward = node->value_prefix - parent_value_prefix; + float true_reward = node->reward; + + // TODO(pu): why in muzero-general is - node.value + min_max_stats.update(true_reward + discount_factor * -node->value()); + + if (node->to_play == to_play) + bootstrap_value = -true_reward + discount_factor * bootstrap_value; + else + bootstrap_value = true_reward + discount_factor * bootstrap_value; + } + } + } + + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &to_play_batch, std::vector &is_chance_list, std::vector &leaf_idx_list) + { + /* + Overview: + Expand the nodes along the search path and update the infos. + Arguments: + - current_latent_state_index: The index of latent state of the leaf node in the search path. + - discount_factor: the discount factor of reward. + - value_prefixs: the value prefixs of nodes along the search path. + - values: the values to propagate along the search path. + - policies: the policy logits of nodes along the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - results: the search results. + - to_play_batch: the batch of which player is playing on this node. 
+        */
+
+        if (leaf_idx_list.empty()) {
+            leaf_idx_list.resize(results.num);
+            for (int i = 0; i < results.num; ++i) {
+                leaf_idx_list[i] = i;
+            }
+        }
+
+        for (int leaf_order = 0; leaf_order < leaf_idx_list.size(); ++leaf_order)
+        {
+            int i = leaf_idx_list[leaf_order];
+            results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[leaf_order], policies[leaf_order], is_chance_list[i]);
+            cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[leaf_order], discount_factor);
+        }
+
+        // for (int i = 0; i < results.num; ++i)
+        // {
+        //     results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i], is_chance_list[i]);
+        //     cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
+        // }
+    }
+
+    int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
+    {
+        /*
+        Overview:
+            Select a child of the given node: by sampling from the prior at chance nodes, or by the ucb score at decision nodes.
+        Arguments:
+            - root: the node whose child is to be selected.
+            - min_max_stats: a tool used to min-max normalize the score.
+            - pb_c_base: constant c2 in muzero.
+            - pb_c_init: constant c1 in muzero.
+            - discount_factor: the discount factor of reward.
+            - mean_q: the mean q value of the parent node.
+            - players: the number of players.
+        Outputs:
+            - action: the action to select.
+        */
+        if (root->is_chance) {
+            // std::cout << "root->is_chance: True " << std::endl;
+
+            // If the node is a chance node, we sample from the prior outcome distribution.
+            std::vector<int> outcomes;
+            std::vector<float> probs;
+
+            for (const auto& kv : root->children) {
+                outcomes.push_back(kv.first);
+                probs.push_back(kv.second.prior);  // 'prior' is a member variable of CNode
+            }
+
+            std::random_device rd;
+            std::mt19937 gen(rd());
+            std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+            int outcome = outcomes[dist(gen)];
+            // std::cout << "Outcome: " << outcome << std::endl;
+
+            return outcome;
+        }
+
+        // std::cout << "root->is_chance: False " << std::endl;
+
+        float max_score = FLOAT_MIN;
+        const float epsilon = 0.000001;
+        std::vector<int> max_index_lst;
+        for (auto a : root->legal_actions)
+        {
+            CNode *child = root->get_child(a);
+            float temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);
+
+            if (max_score < temp_score)
+            {
+                max_score = temp_score;
+
+                max_index_lst.clear();
+                max_index_lst.push_back(a);
+            }
+            else if (temp_score >= max_score - epsilon)
+            {
+                max_index_lst.push_back(a);
+            }
+        }
+
+        int action = 0;
+        if (max_index_lst.size() > 0)
+        {
+            int rand_index = rand() % max_index_lst.size();
+            action = max_index_lst[rand_index];
+        }
+        return action;
+    }
+
+    float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
+    {
+        /*
+        Overview:
+            Compute the ucb score of the child.
+        Arguments:
+            - child: the child node to compute the ucb score for.
+            - min_max_stats: a tool used to min-max normalize the score.
+            - parent_mean_q: the mean q value of the parent node.
+            - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
+            - pb_c_base: constant c2 in muzero.
+ - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - players: the number of players. + Outputs: + - ucb_value: the ucb score of the child. + */ + float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; + pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; + pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); + + prior_score = pb_c * child->prior; + if (child->visit_count == 0) + { + value_score = parent_mean_q; + } + else + { + float true_reward = child->reward; + if (players == 1) + value_score = true_reward + discount_factor * child->value(); + else if (players == 2) + value_score = true_reward + discount_factor * (-child->value()); + } + + value_score = min_max_stats.normalize(value_score); + + if (value_score < 0) + value_score = 0; + if (value_score > 1) + value_score = 1; + + float ucb_value = prior_score + value_score; + return ucb_value; + } + + void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch) + { + /* + Overview: + Search node path from the roots. + Arguments: + - roots: the roots that search from. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - min_max_stats: a tool used to min-max normalize the score. + - results: the search results. + - virtual_to_play_batch: the batch of which player is playing on this node. + */ + // set seed + get_time_and_set_rand_seed(); + + int last_action = -1; + float parent_q = 0.0; + results.search_lens = std::vector(); + + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2 + if (largest_element == -1) + players = 1; + else + players = 2; + + for (int i = 0; i < results.num; ++i) + { + CNode *node = &(roots->roots[i]); + int is_root = 1; + int search_len = 0; + results.search_paths[i].push_back(node); + + // std::cout << "root->is_chance: " <is_chance<< std::endl; + // node->is_chance=false; + + while (node->expanded()) + { + float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); + is_root = 0; + parent_q = mean_q; + // std::cout << "node->is_chance: " <is_chance<< std::endl; + + int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players); + if (players > 1) + { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) + virtual_to_play_batch[i] = 2; + else + virtual_to_play_batch[i] = 1; + } + + node->best_action = action; + // next + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + } + + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.leaf_node_is_chance.push_back(node->is_chance); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + + } + } + +} \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.h b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.h new file mode 100644 
index 000000000..b3fae2997 --- /dev/null +++ b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.h @@ -0,0 +1,95 @@ +// C++11 + +#ifndef CNODE_H +#define CNODE_H + +#include "./../common_lib/cminimax.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const int DEBUG_MODE = 0; + +namespace tree { + + class CNode { + public: + int visit_count, to_play, current_latent_state_index, batch_index, best_action; + float reward, prior, value_sum; + bool is_chance; + int chance_space_size; + std::vector children_index; + std::map children; + + std::vector legal_actions; + + CNode(); + CNode(float prior, std::vector &legal_actions, bool is_chance = false, int chance_space_size = 2); + ~CNode(); + + void expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector &policy_logits, bool is_chance); + void add_exploration_noise(float exploration_fraction, const std::vector &noises); + float compute_mean_q(int isRoot, float parent_q, float discount_factor); + void print_out(); + + int expanded(); + + float value(); + + std::vector get_trajectory(); + std::vector get_children_distribution(); + CNode* get_child(int action); + }; + + class CRoots{ + public: + int root_num; + std::vector roots; + std::vector > legal_actions_list; + int chance_space_size; + + CRoots(); + CRoots(int root_num, std::vector > &legal_actions_list, int chance_space_size); + ~CRoots(); + + void prepare(float root_noise_weight, const std::vector > &noises, const std::vector &rewards, const std::vector > &policies, std::vector &to_play_batch); + void prepare_no_noise(const std::vector &rewards, const std::vector > &policies, std::vector &to_play_batch); + void clear(); + std::vector > get_trajectories(); + std::vector > get_distributions(); + std::vector get_values(); + + }; + + class CSearchResults{ + public: + int num; + std::vector latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens; + std::vector virtual_to_play_batchs; + std::vector nodes; + std::vector leaf_node_is_chance; + std::vector > search_paths; + + CSearchResults(); + CSearchResults(int num); + ~CSearchResults(); + + }; + + + //********************************************************* + void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players); + void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor); + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &rewards, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &to_play_batch, std::vector & is_chance_list, std::vector &leaf_idx_list); + int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players); + float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players); + void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch); +} + +#endif \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pxd 
b/lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pxd new file mode 100644 index 000000000..a24f895d5 --- /dev/null +++ b/lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pxd @@ -0,0 +1,74 @@ +# distutils:language=c++ +# cython:language_level=3 +from libcpp.vector cimport vector +from libcpp cimport bool + +cdef extern from "../common_lib/cminimax.cpp": + pass + + +cdef extern from "../common_lib/cminimax.h" namespace "tools": + cdef cppclass CMinMaxStats: + CMinMaxStats() except + + float maximum, minimum, value_delta_max + + void set_delta(float value_delta_max) + void update(float value) + void clear() + float normalize(float value) + + cdef cppclass CMinMaxStatsList: + CMinMaxStatsList() except + + CMinMaxStatsList(int num) except + + int num + vector[CMinMaxStats] stats_lst + + void set_delta(float value_delta_max) + +cdef extern from "lib/cnode.cpp": + pass + + +cdef extern from "lib/cnode.h" namespace "tree": + cdef cppclass CNode: + CNode() except + + CNode(float prior, vector[int] &legal_actions, bool is_chance, int chance_space_size) except + + int visit_count, to_play, current_latent_state_index, batch_index, best_action + float value_prefixs, prior, value_sum, parent_value_prefix + + void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefixs, vector[float] policy_logits, bool is_chance) + void add_exploration_noise(float exploration_fraction, vector[float] noises) + float compute_mean_q(int isRoot, float parent_q, float discount_factor) + + int expanded() + float value() + vector[int] get_trajectory() + vector[int] get_children_distribution() + CNode* get_child(int action) + + cdef cppclass CRoots: + CRoots() except + + CRoots(int root_num, vector[vector[int]] legal_actions_list, int chance_space_size) except + + int root_num, chance_space_size + vector[CNode] roots + + void prepare(float root_noise_weight, const vector[vector[float]] &noises, const vector[float] &value_prefixs, const vector[vector[float]] &policies, vector[int] to_play_batch) + void prepare_no_noise(const vector[float] &value_prefixs, const vector[vector[float]] &policies, vector[int] to_play_batch) + void clear() + vector[vector[int]] get_trajectories() + vector[vector[int]] get_distributions() + vector[float] get_values() + + cdef cppclass CSearchResults: + CSearchResults() except + + CSearchResults(int num) except + + int num + vector[int] latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens + vector[int] virtual_to_play_batchs + vector[bool] leaf_node_is_chance + vector[CNode*] nodes + + cdef void cbackpropagate(vector[CNode*] &search_path, CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor) + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, vector[float] values, vector[vector[float]] policies, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &to_play_batch, vector[bool] &is_chance_list, vector[int] &leaf_idx_list) + void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, vector[int] &virtual_to_play_batch) diff --git a/lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pyx b/lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pyx new file mode 100644 index 000000000..74453c964 --- /dev/null +++ b/lzero/mcts/ctree/ctree_stochastic_muzero/stochastic_mz_tree.pyx @@ -0,0 +1,91 @@ +# 
distutils: language=c++ +# cython:language_level=3 +from libcpp.vector cimport vector +from libcpp cimport bool + +cdef class MinMaxStatsList: + cdef CMinMaxStatsList *cmin_max_stats_lst + + def __cinit__(self, int num): + self.cmin_max_stats_lst = new CMinMaxStatsList(num) + + def set_delta(self, float value_delta_max): + self.cmin_max_stats_lst[0].set_delta(value_delta_max) + + def __dealloc__(self): + del self.cmin_max_stats_lst + +cdef class ResultsWrapper: + cdef CSearchResults cresults + + def __cinit__(self, int num): + self.cresults = CSearchResults(num) + + def get_search_len(self): + return self.cresults.search_lens + +cdef class Roots: + cdef int root_num + cdef CRoots *roots + + def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list, int chance_space_size): + self.root_num = root_num + self.roots = new CRoots(root_num, legal_actions_list, chance_space_size) + + def prepare(self, float root_noise_weight, list noises, list value_prefix_pool, list policy_logits_pool, + vector[int] & to_play_batch): + self.roots[0].prepare(root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch) + + def prepare_no_noise(self, list value_prefix_pool, list policy_logits_pool, vector[int] & to_play_batch): + self.roots[0].prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play_batch) + + def get_trajectories(self): + return self.roots[0].get_trajectories() + + def get_distributions(self): + return self.roots[0].get_distributions() + + def get_values(self): + return self.roots[0].get_values() + + def clear(self): + self.roots[0].clear() + + def __dealloc__(self): + del self.roots + + @property + def num(self): + return self.root_num + +cdef class Node: + cdef CNode cnode + + def __cinit__(self): + pass + + def __cinit__(self, float prior, vector[int] & legal_actions, bool is_chance, int chance_space_size): + pass + + def expand(self, int to_play, int current_latent_state_index, int batch_index, float value_prefix, + list policy_logits, bool is_chance): + cdef vector[float] cpolicy = policy_logits + self.cnode.expand(to_play, current_latent_state_index, batch_index, value_prefix, cpolicy, is_chance) + +def batch_backpropagate(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies, + MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list to_play_batch, list is_chance_list, list leaf_idx_list): + cdef int i + cdef vector[float] cvalue_prefixs = value_prefixs + cdef vector[float] cvalues = values + cdef vector[vector[float]] cpolicies = policies + + cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, + min_max_stats_lst.cmin_max_stats_lst, results.cresults, to_play_batch, is_chance_list, leaf_idx_list) + +def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, + ResultsWrapper results, list virtual_to_play_batch): + cbatch_traverse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults, + virtual_to_play_batch) + + return results.cresults.leaf_node_is_chance, results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + diff --git a/lzero/mcts/tree_search/__init__.py b/lzero/mcts/tree_search/__init__.py index 581a67844..41a89655d 100644 --- a/lzero/mcts/tree_search/__init__.py +++ b/lzero/mcts/tree_search/__init__.py @@ -1,5 +1,6 
@@ from .mcts_ctree import MuZeroMCTSCtree, EfficientZeroMCTSCtree, GumbelMuZeroMCTSCtree from .mcts_ctree_sampled import SampledEfficientZeroMCTSCtree +from .mcts_ctree_stochastic import StochasticMuZeroMCTSCtree from .mcts_ptree import MuZeroMCTSPtree, EfficientZeroMCTSPtree from .mcts_ptree_sampled import SampledEfficientZeroMCTSPtree from .mcts_ptree_stochastic import StochasticMuZeroMCTSPtree diff --git a/lzero/mcts/tree_search/mcts_ctree_stochastic.py b/lzero/mcts/tree_search/mcts_ctree_stochastic.py new file mode 100644 index 000000000..d2e9a2300 --- /dev/null +++ b/lzero/mcts/tree_search/mcts_ctree_stochastic.py @@ -0,0 +1,237 @@ +import copy +from typing import TYPE_CHECKING, List, Any, Union + +import numpy as np +import torch +from easydict import EasyDict + +from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy +from lzero.mcts.ctree.ctree_stochastic_muzero import stochastic_mz_tree + + +# ============================================================== +# MuZero +# ============================================================== + + +class StochasticMuZeroMCTSCtree(object): + """ + Overview: + MCTSCtree for MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + + Interfaces: + __init__, roots, search + """ + + config = dict( + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_base=19652, + # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_init=1.25, + # (float) The maximum change in value allowed during the backup step of the search tree update. + value_delta_max=0.01, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict = None) -> None: + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + @classmethod + def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any], chance_space_size: int=2) -> "stochastic_mz_tree.Roots": + """ + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num (:obj:`int`): the number of the current root. + - legal_action_list (:obj:`list`): the vector of the legal action of this root. + """ + from lzero.mcts.ctree.ctree_stochastic_muzero import stochastic_mz_tree as ctree + return ctree.Roots(active_collect_env_num, legal_actions, chance_space_size) + + def search( + self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, + List[Any]] + ) -> None: + """ + Overview: + Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. + Use the cpp ctree. 
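+            Unlike the plain MuZero tree, a leaf here can be either a decision node or a chance node,
+            and the two kinds are expanded with different model calls (the ``afterstate`` flag of ``recurrent_inference``).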
+ Arguments: + - roots (:obj:`Any`): a batch of expanded root nodes + - latent_state_roots (:obj:`list`): the hidden states of the roots + - to_play_batch (:obj:`list`): the to_play_batch list used in in self-play-mode board games + """ + with torch.no_grad(): + model.eval() + + # preparation some constant + batch_size = roots.num + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + # the data storage of latent states: storing the latent state of all the nodes in the search. + latent_state_batch_in_search_path = [latent_state_roots] + + # minimax value storage + min_max_stats_lst = stochastic_mz_tree.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + + for simulation_index in range(self._cfg.num_simulations): + # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. + + latent_states = [] + + # prepare a result wrapper to transport results between python and c++ parts + results = stochastic_mz_tree.ResultsWrapper(num=batch_size) + + # latent_state_index_in_search_path: the first index of leaf node states in latent_state_batch_in_search_path, i.e. is current_latent_state_index in one the search. + # latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. + # e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index. + # The index of value prefix hidden state of the leaf node are in the same manner. + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. + """ + leaf_node_is_chance, latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = stochastic_mz_tree.batch_traverse( + roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, + copy.deepcopy(to_play_batch) + ) + + # obtain the latent state for leaf node + for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): + latent_states.append(latent_state_batch_in_search_path[ix][iy]) + + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device).float() + # .long() is only for discrete action + last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() + """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. 
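+                In Stochastic MuZero the backup below is done in two passes: one ``batch_backpropagate`` call
+                for leaves that are chance nodes and one for leaves that are decision nodes.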
+ """ + # network_output = model.recurrent_inference(latent_states, last_actions) + + num = len(leaf_node_is_chance) + leaf_idx_list = list(range(num)) + latent_state_batch = [None] * num + value_batch = [None] * num + reward_batch = [None] * num + policy_logits_batch = [None] * num + child_is_chance_batch = [None] * num + chance_nodes_index = [] + decision_nodes_index = [] + + for i, leaf_node_is_chance_ in enumerate(leaf_node_is_chance): + if leaf_node_is_chance_: + chance_nodes_index.append(i) + else: + decision_nodes_index.append(i) + + def process_nodes(nodes_index, is_chance): + # Return early if nodes_index is empty + if not nodes_index: + return + + # Slice and stack latent_states and last_actions based on nodes_index + latent_states_stack = torch.stack([latent_states[i] for i in nodes_index], dim=0) + last_actions_stack = torch.stack([last_actions[i] for i in nodes_index], dim=0) + + # Pass the stacked batch through the recurrent_inference function + network_output_batch = model.recurrent_inference(latent_states_stack, + last_actions_stack, + afterstate=not is_chance) + + # Split the batch output into separate nodes + latent_state_splits = torch.split(network_output_batch.latent_state, 1, dim=0) + value_splits = torch.split(network_output_batch.value, 1, dim=0) + reward_splits = torch.split(network_output_batch.reward, 1, dim=0) + policy_logits_splits = torch.split(network_output_batch.policy_logits, 1, dim=0) + + for i, (latent_state, value, reward, policy_logits) in zip(nodes_index, + zip(latent_state_splits, value_splits, + reward_splits, + policy_logits_splits)): + if not model.training: + value = self.inverse_scalar_transform_handle(value).detach().cpu().numpy() + reward = self.inverse_scalar_transform_handle(reward).detach().cpu().numpy() + latent_state = latent_state.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy() + + latent_state_batch[i] = latent_state + value_batch[i] = value.reshape(-1).tolist() + reward_batch[i] = reward.reshape(-1).tolist() + policy_logits_batch[i] = policy_logits.tolist() + child_is_chance_batch[i] = is_chance + + process_nodes(chance_nodes_index, True) + process_nodes(decision_nodes_index, False) + chance_nodes = chance_nodes_index + decision_nodes = decision_nodes_index + + value_batch_chance = [value_batch[leaf_idx] for leaf_idx in chance_nodes] + value_batch_decision = [value_batch[leaf_idx] for leaf_idx in decision_nodes] + reward_batch_chance = [reward_batch[leaf_idx] for leaf_idx in chance_nodes] + reward_batch_decision = [reward_batch[leaf_idx] for leaf_idx in decision_nodes] + policy_logits_batch_chance = [policy_logits_batch[leaf_idx] for leaf_idx in chance_nodes] + policy_logits_batch_decision = [policy_logits_batch[leaf_idx] for leaf_idx in decision_nodes] + + latent_state_batch = np.concatenate(latent_state_batch, axis=0) + latent_state_batch_in_search_path.append(latent_state_batch) + current_latent_state_index = simulation_index + 1 + + if(len(chance_nodes) > 0): + value_batch_chance = np.concatenate(value_batch_chance, axis=0).reshape(-1).tolist() + reward_batch_chance = np.concatenate(reward_batch_chance, axis=0).reshape(-1).tolist() + policy_logits_batch_chance = np.concatenate(policy_logits_batch_chance, axis=0).tolist() + stochastic_mz_tree.batch_backpropagate( + current_latent_state_index, discount_factor, reward_batch_chance, value_batch_chance, policy_logits_batch_chance, + min_max_stats_lst, results, virtual_to_play_batch, child_is_chance_batch, chance_nodes + ) + if(len(decision_nodes)>0): + 
value_batch_decision = np.concatenate(value_batch_decision, axis=0).reshape(-1).tolist() + reward_batch_decision = np.concatenate(reward_batch_decision, axis=0).reshape(-1).tolist() + policy_logits_batch_decision = np.concatenate(policy_logits_batch_decision, axis=0).tolist() + stochastic_mz_tree.batch_backpropagate( + current_latent_state_index, discount_factor, reward_batch_decision, value_batch_decision, policy_logits_batch_decision, + min_max_stats_lst, results, virtual_to_play_batch, child_is_chance_batch, decision_nodes + ) + + + + + # latent_state_batch = np.concatenate(latent_state_batch, axis=0) + # value_batch = np.concatenate(value_batch, axis=0).reshape(-1).tolist() + # reward_batch = np.concatenate(reward_batch, axis=0).reshape(-1).tolist() + # policy_logits_batch = np.concatenate(policy_logits_batch, axis=0).tolist() + # latent_state_batch_in_search_path.append(latent_state_batch) + + # # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # # statistics. + + # # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. + # current_latent_state_index = simulation_index + 1 + # stochastic_mz_tree.batch_backpropagate( + # current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, + # min_max_stats_lst, results, virtual_to_play_batch, child_is_chance_batch, leaf_idx_list + # ) diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 401095952..9087c4cbf 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -10,7 +10,7 @@ from ding.utils import POLICY_REGISTRY from torch.nn import L1Loss -# from lzero.mcts import StochasticMuZeroMCTSCtree as MCTSCtree +from lzero.mcts import StochasticMuZeroMCTSCtree as MCTSCtree from lzero.mcts import StochasticMuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ diff --git a/zoo/atari/config/atari_stochastic_muzero.py b/zoo/atari/config/atari_stochastic_muzero.py index 1fbefb784..f28975382 100644 --- a/zoo/atari/config/atari_stochastic_muzero.py +++ b/zoo/atari/config/atari_stochastic_muzero.py @@ -54,7 +54,8 @@ norm_type='BN', ), cuda=True, - mcts_ctree=False, + gumbel_algo=False, + mcts_ctree=True, env_type='not_board_games', game_segment_length=400, use_augmentation=True, From 14e382202f242bbd97efb693f13d80933bf6070c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= Date: Tue, 13 Jun 2023 15:16:57 +0800 Subject: [PATCH 03/28] add box2d, classic conrol, and 2048 config --- ...o.py => atari_stochastic_muzero_config.py} | 12 +- ...narlander_disc_stochastic_muzero_config.py | 97 +++ .../cartpole_stochastic_muzero_config.py | 96 +++ zoo/game_2048/config/muzero_2048_config.py | 89 +++ .../config/rule_based_2048_config.py | 201 +++++++ .../config/stochastic_muzero_2048_config.py | 93 +++ zoo/game_2048/envs/__init__.py | 0 zoo/game_2048/envs/game_2048_env.py | 559 ++++++++++++++++++ 8 files changed, 1141 insertions(+), 6 deletions(-) rename zoo/atari/config/{atari_stochastic_muzero.py => atari_stochastic_muzero_config.py} (96%) create mode 100644 zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py create mode 100644 zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py create mode 100644 
zoo/game_2048/config/muzero_2048_config.py create mode 100644 zoo/game_2048/config/rule_based_2048_config.py create mode 100644 zoo/game_2048/config/stochastic_muzero_2048_config.py create mode 100644 zoo/game_2048/envs/__init__.py create mode 100644 zoo/game_2048/envs/game_2048_env.py diff --git a/zoo/atari/config/atari_stochastic_muzero.py b/zoo/atari/config/atari_stochastic_muzero_config.py similarity index 96% rename from zoo/atari/config/atari_stochastic_muzero.py rename to zoo/atari/config/atari_stochastic_muzero_config.py index f28975382..8083d579d 100644 --- a/zoo/atari/config/atari_stochastic_muzero.py +++ b/zoo/atari/config/atari_stochastic_muzero_config.py @@ -17,12 +17,12 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -collector_env_num = 1 -n_episode = 1 -evaluator_env_num = 1 -num_simulations = 5 -update_per_collect = 10 -batch_size = 25 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 max_env_step = int(1e6) reanalyze_ratio = 0. chance_space_size = 2 diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py new file mode 100644 index 000000000..444d5386b --- /dev/null +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py @@ -0,0 +1,97 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(5e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +lunarlander_muzero_config = dict( + exp_name=f'data_stochastic_mz_ctree/lunarlander_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name='LunarLander-v2', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=8, + action_space_size=4, + chance_space_size=2, + model_type='conv', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + gumbel_algo=False, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +lunarlander_muzero_config = EasyDict(lunarlander_muzero_config) +main_config = lunarlander_muzero_config + +lunarlander_muzero_create_config = dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='stochastic_muzero', + import_names=['lzero.policy.stochastic_muzero'], + ), + collector=dict( + type='episode_muzero', + get_train_sample=True, + import_names=['lzero.worker.muzero_collector'], + ) +) +lunarlander_muzero_create_config = EasyDict(lunarlander_muzero_create_config) +create_config = lunarlander_muzero_create_config + +if __name__ == "__main__": + # Users can use different train entry by specifying the entry_type. + entry_type = "train_muzero" # options={"train_muzero", "train_muzero_with_gym_env"} + + if entry_type == "train_muzero": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) + elif entry_type == "train_muzero_with_gym_env": + """ + The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper. + Users can refer to lzero/envs/wrappers for more details. + """ + from lzero.entry import train_muzero_with_gym_env + train_muzero_with_gym_env([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py new file mode 100644 index 000000000..3b30fc4b1 --- /dev/null +++ b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py @@ -0,0 +1,96 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +reanalyze_ratio = 0 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cartpole_stochastic_muzero_config = dict( + exp_name=f'data_stochastic_mz_ctree/cartpole_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + chance_space_size=2, + model_type='conv', + lstm_hidden_size=128, + latent_state_dim=128, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + mcts_ctree=True, + gumbel_algo=False, + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. 
+ num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +cartpole_stochastic_muzero_config = EasyDict(cartpole_stochastic_muzero_config) +main_config = cartpole_stochastic_muzero_config + +cartpole_stochastic_muzero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='stochastic_muzero', + import_names=['lzero.policy.stochastic_muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +cartpole_stochastic_muzero_create_config = EasyDict(cartpole_stochastic_muzero_create_config) +create_config = cartpole_stochastic_muzero_create_config + +if __name__ == "__main__": + # Users can use different train entry by specifying the entry_type. + entry_type = "train_muzero" # options={"train_muzero", "train_muzero_with_gym_env"} + + if entry_type == "train_muzero": + from lzero.entry import train_muzero + elif entry_type == "train_muzero_with_gym_env": + """ + The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper. + Users can refer to lzero/envs/wrappers for more details. + """ + from lzero.entry import train_muzero_with_gym_env as train_muzero + + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py new file mode 100644 index 000000000..52bbcd5c7 --- /dev/null +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -0,0 +1,89 @@ +from easydict import EasyDict + +env_name = 'game_2048' +action_space_size = 4 +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 1 +n_episode = 1 +evaluator_env_num = 1 +num_simulations = 5 +update_per_collect = 3 +batch_size = 5 +max_env_step = int(1e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_muzero_config = dict( + exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-true_seed0', + env=dict( + stop_value=int(1e6), + env_name=env_name, + obs_shape=(16, 4, 4), + obs_type='dict_observation', + reward_normalize=True, + reward_scale=100, + max_tile=int(2**16), # 2**11=2048, 2**16=65536 + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(16, 4, 4), + action_space_size=action_space_size, + image_channel=16, + # NOTE: whether to use the self_supervised_learning_loss. 
default is False + self_supervised_learning_loss=True, + ), + mcts_ctree=True, + gumbel_algo=False, + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + td_steps=10, + discount_factor=0.999, + manual_temperature_decay=True, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, # init lr for manually decay schedule + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +atari_muzero_config = EasyDict(atari_muzero_config) +main_config = atari_muzero_config + +atari_muzero_create_config = dict( + env=dict( + type='game_2048', + import_names=['zoo.game_2048.envs.game_2048_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +atari_muzero_create_config = EasyDict(atari_muzero_create_config) +create_config = atari_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/game_2048/config/rule_based_2048_config.py b/zoo/game_2048/config/rule_based_2048_config.py new file mode 100644 index 000000000..191411ead --- /dev/null +++ b/zoo/game_2048/config/rule_based_2048_config.py @@ -0,0 +1,201 @@ +import numpy as np +from zoo.game_2048.envs.game_2048_env import Game2048Env, IllegalMove +import pytest +from easydict import EasyDict + +from typing import Tuple, Union +from rich import print +from functools import lru_cache +import time +import numpy as np + + +def rule_based_search(grid: np.array, fast_search: bool = True) -> int: + + model1 = np.array([[16, 15, 14, 13], [9, 10, 11, 12], [8, 7, 6, 5], [1, 2, 2, 4]]) + model2 = np.array([[16, 15, 12, 4], [14, 13, 11, 3], [10, 9, 8, 2], [7, 6, 5, 1]]) + model3 = np.array([[16, 15, 14, 4], [13, 12, 11, 3], [10, 9, 8, 2], [7, 6, 5, 1]]) + + @lru_cache(maxsize=512) + def get_model_score(value, i, j): + result = np.zeros(3 * 8) + for k, m in enumerate([model1, model2, model3]): + start = k * 8 + result[start] += m[i, j] * value + result[start + 1] += m[i, 3 - j] * value + result[start + 2] += m[j, i] * value + result[start + 3] += m[3 - j, i] * value + result[start + 4] += m[3 - i, 3 - j] * value + result[start + 5] += m[3 - i, j] * value + result[start + 6] += m[j, 3 - i] * value + result[start + 7] += m[3 - j, 3 - i] * value + return result + + def get_score(grid: np.array) -> float: + result = np.zeros(3 * 8) + for i in range(4): + for j in range(4): + if grid[i, j] != 0: + result += get_model_score(grid[i, j], i, j) + # for k, m in enumerate([model1, model2, model3]): + # start = k * 8 + # value = grid[i, j] # whether use log2 here + # result[start] += m[i, j] * value + # result[start + 1] += m[i, 3 - j] * value + # result[start + 2] += m[j, i] * value + # result[start + 3] += m[3 - j, i] * value + # result[start + 4] += m[3 - i, 3 - j] * value + # result[start + 5] += m[3 - i, j] * value + # result[start + 6] += m[j, 3 - i] * value + # result[start + 7] += m[3 - j, 3 - i] * value + + return result.max() + + def expectation_search(grid: np.array, depth: int, chance_node: 
bool) -> Tuple[float, Union[int, None]]: + if depth == 0: + return get_score(grid), None + if chance_node: + cum_score = 0. + if fast_search: + choices = [[2, 0.9]] + else: + choices = zip([2, 4], [0.9, 0.1]) + for value, prob in choices: + value, prob = 2, 0.9 + for i in range(4): + for j in range(4): + if grid[i, j] == 0: + grid[i, j] = value + cum_score += prob * expectation_search(grid, depth - 1, False)[0] + grid[i, j] = 0 + empty_count = np.sum(grid == 0) + cum_score /= empty_count + return cum_score, None + else: + best_score = 0 + best_action = None + # 0, 1, 2, 3 mean top, right, bottom, left + for dire in [0, 1, 2, 3]: + new_grid, move_flag, _ = move(grid, dire) + if move_flag: + score = expectation_search(new_grid, depth - 1, True)[0] + if score > best_score: + best_score = score + best_action = dire + return best_score, best_action + + # depth selection + grid_max = grid.max() + if grid_max >= 2048: + depth = 6 + elif grid_max >= 1024: + depth = 5 + else: + depth = 4 + # rule_based_search + _, best_action = expectation_search(grid, depth, False) + return best_action + + +def move(grid: np.array, action: int, game_score: int = 0) -> Tuple[np.array, bool, int]: + # execute action in 2048 game + # 0, 1, 2, 3 mean top, right, bottom, left + assert action in [0, 1, 2, 3], action + old_grid = grid + grid = np.copy(grid) + # rotate + if action == 0: + grid = np.rot90(grid) + elif action == 1: + grid = np.rot90(grid, k=3) + elif action == 2: + grid = np.rot90(grid, k=2) + # simple move + for i in range(4): + for j in range(3): + if grid[i, j] == 0: + grid[i, j] = grid[i, j + 1] + grid[i, j + 1] = 0 + # merge + for i in range(4): + for j in range(3): + if grid[i, j] == grid[i, j + 1]: + game_score += 2 * grid[i, j] + grid[i, j] *= 2 + grid[i, j + 1] = 0 + # simple move + for i in range(4): + for j in range(3): + if grid[i, j] == 0: + grid[i, j] = grid[i, j + 1] + grid[i, j + 1] = 0 + # rotate back + if action == 0: + grid = np.rot90(grid, k=3) + elif action == 1: + grid = np.rot90(grid) + elif action == 2: + grid = np.rot90(grid, k=2) + move_flag = np.any(old_grid != grid) + return grid, move_flag, game_score + + +def generate(grid: np.array) -> np.array: + # random generate a new number in empty location + # 2 or 4 + number = np.random.choice([2, 4], p=[0.9, 0.1]) + # get empty location + empty = np.where(grid == 0) + # random select one + index = np.random.randint(len(empty[0])) + # set new number + grid[empty[0][index], empty[1][index]] = number + # return new grid + return grid + + + + +config = EasyDict(dict( + env_name="game_2048_env_2048", + save_replay_gif=False, + replay_path_gif=None, + replay_path=None, + act_scale=True, + channel_last=True, + obs_type='array', + reward_normalize=True, + reward_scale=100, + max_tile=int(2**16), + delay_reward_step=0, + prob_random_agent=0., + max_episode_steps=int(1e4), + is_collect=False, + ignore_legal_actions=True, + need_flatten=False, +)) + +if __name__ == "__main__": + game_2048_env = Game2048Env(config) + obs = game_2048_env.reset() + print('init board state: ') + game_2048_env.render() + step = 0 + while True: + # action = env.human_to_action() + print('='*20) + grid = obs.astype(np.int64) + action = game_2048_env.random_action() + action = rule_based_search(grid) + if(action == 1): + action=2 + elif(action == 2): + action = 1 + obs, reward, done, info = game_2048_env.step(action) + step += 1 + print(f"step: {step}, action: {action}, reward: {reward}, raw_reward: {info['raw_reward']}") + game_2048_env.render() + + if done: + 
print('total_step_number: {}'.format(step)) + break diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py new file mode 100644 index 000000000..79c04d577 --- /dev/null +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -0,0 +1,93 @@ +from easydict import EasyDict + +env_name = 'game_2048' +action_space_size = 4 +chance_space_size = 4 +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 1 +n_episode = 1 +evaluator_env_num = 1 +num_simulations = 5 # TODO(pu):100 +update_per_collect = 3 +batch_size = 5 +max_env_step = int(1e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +game_2048_stochastic_muzero_config = dict( + exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-true_seed0', + env=dict( + stop_value=int(1e6), + env_name=env_name, + obs_shape=(16, 4, 4), + obs_type='dict_observation', + reward_normalize=True, + reward_scale=100, + max_tile=int(2**16), # 2**11=2048, 2**16=65536 + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(16, 4, 4), + action_space_size=action_space_size, + chance_space_size=chance_space_size, + image_channel=16, + # NOTE: whether to use the self_supervised_learning_loss. default is False + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + mcts_ctree=True, + gumbel_algo=False, + cuda=True, + env_type='not_board_games', + game_segment_length=400, + update_per_collect=update_per_collect, + batch_size=batch_size, + td_steps=10, + discount_factor=0.999, + manual_temperature_decay=True, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, # init lr for manually decay schedule + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +game_2048_stochastic_muzero_config = EasyDict(game_2048_stochastic_muzero_config) +main_config = game_2048_stochastic_muzero_config + +game_2048_stochastic_muzero_create_config = dict( + env=dict( + type='game_2048', + import_names=['zoo.game_2048.envs.game_2048_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='stochastic_muzero', + import_names=['lzero.policy.stochastic_muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +game_2048_stochastic_muzero_create_config = EasyDict(game_2048_stochastic_muzero_create_config) +create_config = game_2048_stochastic_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/game_2048/envs/__init__.py b/zoo/game_2048/envs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py new file mode 100644 index 000000000..4379a6c31 --- /dev/null +++ b/zoo/game_2048/envs/game_2048_env.py @@ -0,0 +1,559 @@ +from __future__ import print_function + +import copy +import itertools +import logging +import sys +from typing import List + +import gym +import numpy as np +from PIL import Image, ImageDraw, ImageFont +from ding.envs import BaseEnvTimestep +from ding.torch_utils import to_ndarray +from ding.utils import ENV_REGISTRY +from easydict import EasyDict +from gym import spaces +from gym.utils import seeding +from six import StringIO + + +@ENV_REGISTRY.register('game_2048') +class Game2048Env(gym.Env): + config = dict( + env_name="game_2048", + save_replay_gif=False, + replay_path_gif=None, + replay_path=None, + act_scale=True, + channel_last=True, + obs_type='raw_observation', # options=['raw_observation', 'dict_observation', 'array'] + reward_normalize=True, + reward_scale=100, + max_tile=int(2**16), # 2**11=2048, 2**16=65536 + delay_reward_step=0, + prob_random_agent=0., + max_episode_steps=int(1e6), + is_collect=True, + ignore_legal_actions = True, + need_flatten = False, + ) + metadata = {'render.modes': ['human', 'ansi', 'rgb_array']} + + @classmethod + def default_config(cls: type) -> EasyDict: + cfg = EasyDict(copy.deepcopy(cls.config)) + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: dict) -> None: + self._cfg = cfg + self._init_flag = False + self._env_name = cfg.env_name + self._replay_path = cfg.get('replay_path', None) + self._replay_path_gif = cfg.get('replay_path_gif', None) + self._save_replay_gif = cfg.get('save_replay_gif', False) + self._save_replay_count = 0 + self.channel_last = cfg.channel_last + self.obs_type = cfg.obs_type + self.reward_normalize = cfg.reward_normalize + self.reward_scale = cfg.reward_scale + self.max_tile = cfg.max_tile + self.max_episode_steps = cfg.max_episode_steps + self.is_collect = cfg.is_collect + self.ignore_legal_actions = cfg.ignore_legal_actions + self.need_flatten = cfg.need_flatten + self.chance = 0 + + self.size = 4 + self.w = self.size + self.h = self.size + self.squares = self.size * self.size + + self.max_value = 2 + + self.episode_return = 0 + # Members for gym implementation: + self._action_space = spaces.Discrete(4) + self._observation_space = spaces.Box(0, 1, (self.w, self.h, self.squares), dtype=int) + + self.set_illegal_move_reward(0.) 
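+        # NOTE: the observation handed to the agent is a (4, 4, 16) one-hot encoding of the board
+        # (see ``encoding_board`` at the bottom of this file); it is transposed to (16, 4, 4) when
+        # ``channel_last`` is False.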
+ self.set_max_tile(max_tile=self.max_tile) + + if self.reward_normalize: + self._reward_range = (0., self.max_tile) + else: + self._reward_range = (0., self.max_tile) + + # TODO(pu): why + self.grid_size = 70 + + # Initialise the random seed of the gym environment. + self.seed() + + def seed(self, seed=None, seed1=None): + """Set the random seed for the gym environment.""" + self.np_random, seed = seeding.np_random(seed) + return [seed] + + def set_illegal_move_reward(self, reward): + """Define the reward/penalty for performing an illegal move. Also need + to update the reward range for this.""" + # Guess that the maximum reward is also 2**squares though you'll probably never get that. + # (assume that illegal move reward is the lowest value that can be returned + # TODO: check that this is correct + self.illegal_move_reward = reward + self.reward_range = (self.illegal_move_reward, float(2 ** self.squares)) + + def set_max_tile(self, max_tile: int = 2048): + """ + Define the maximum tile that will end the game (e.g. 2048). None means no limit. + This does not affect the state returned. + """ + assert max_tile is None or isinstance(max_tile, int) + self.max_tile = max_tile + + def reset(self): + """Reset the game board-matrix and add 2 tiles.""" + self.episode_length = 0 + self.board = np.zeros((self.h, self.w), np.int32) + self.episode_return = 0 + self._final_eval_reward = 0.0 + self.should_done = False + self.max_value = 2 + + logging.debug("Adding tiles") + # TODO(pu): why add_tiles twice? + self.add_random_2_4_tile() + self.add_random_2_4_tile() + + action_mask = np.zeros(4, 'int8') + action_mask[self.legal_actions] = 1 + + observation = encoding_board(self.board) + observation = observation.astype(np.float32) + assert observation.shape == (4, 4, 16) + + if not self.channel_last: + # move channel dim to fist axis + # (W, H, C) -> (C, W, H) + # e.g. (4, 4, 16) -> (16, 4, 4) + observation = np.transpose(observation, [2, 0, 1]) + if self.need_flatten: + observation = observation.reshape(-1) + + if self.obs_type == 'dict_observation': + observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} + elif self.obs_type == 'array': + observation = self.board + else: + observation = observation + return observation + + def step(self, action): + """Perform one step of the game. This involves moving and adding a new tile.""" + self.episode_length += 1 + info = {'illegal_move': False} + + if action not in self.legal_actions: + raise IllegalActionError(f"You input illegal action: {action}, the legal_actions are {self.legal_actions}. 
") + + empty_num1 = len(self.get_empty_location()) + reward_eval = float(self.move(action)) + empty_num2 = len(self.get_empty_location()) + reward_collect = float(empty_num2 - empty_num1) + #reward_collect = float(empty_num1 - empty_num2) + max_num = np.max(self.board) + if max_num > self.max_value: + reward_collect += np.log2(max_num) * 0.1 + self.max_value = max_num + self.episode_return += reward_eval + assert reward_eval <= 2 ** (self.w * self.h) + self.add_random_2_4_tile() + done = self.is_end() + reward_collect = float(reward_collect) + reward_eval = float(reward_eval) + + if self.episode_length >= self.max_episode_steps: + # print("episode_length: {}".format(self.episode_length)) + done = True + + observation = encoding_board(self.board) + observation = observation.astype(np.float32) + + assert observation.shape == (4, 4, 16) + + if not self.channel_last: + # move channel dim to fist axis + # (W, H, C) -> (C, W, H) + # e.g. (4, 4, 16) -> (16, 4, 4) + observation = np.transpose(observation, [2, 0, 1]) + + if self.need_flatten: + observation = observation.reshape(-1) + action_mask = np.zeros(4, 'int8') + action_mask[self.legal_actions] = 1 + + if self.obs_type == 'dict_observation': + observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} + elif self.obs_type == 'array': + observation = self.board + else: + observation = observation + + if self.reward_normalize: + reward_normalize = reward_collect + self._final_eval_reward += reward_normalize + reward = reward_collect + else: + self._final_eval_reward += reward_eval + reward = reward_eval + reward = to_ndarray([reward]).astype(np.float32) + + info = {"raw_reward": reward_eval, "max_tile": self.highest(), 'highest': self.highest()} + + if done: + info['eval_episode_return'] = self._final_eval_reward + + if self.reward_normalize: + return BaseEnvTimestep(observation, reward, done, info) + else: + return BaseEnvTimestep(observation, reward, done, info) + + def render(self, mode='human'): + if mode == 'rgb_array': + black = (0, 0, 0) + grey = (128, 128, 128) + white = (255, 255, 255) + tile_colour_map = { + 2: (255, 0, 0), + 4: (224, 32, 0), + 8: (192, 64, 0), + 16: (160, 96, 0), + 32: (128, 128, 0), + 64: (96, 160, 0), + 128: (64, 192, 0), + 256: (32, 224, 0), + 512: (0, 255, 0), + 1024: (0, 224, 32), + 2048: (0, 192, 64), + 4096: (0, 160, 96), + } + grid_size = self.grid_size + + # Render with Pillow + pil_board = Image.new("RGB", (grid_size * 4, grid_size * 4)) + draw = ImageDraw.Draw(pil_board) + draw.rectangle([0, 0, 4 * grid_size, 4 * grid_size], grey) + fnt = ImageFont.truetype('Arial.ttf', 30) + + for y in range(4): + for x in range(4): + o = self.get(y, x) + if o: + draw.rectangle([x * grid_size, y * grid_size, (x + 1) * grid_size, (y + 1) * grid_size], + tile_colour_map[o]) + (text_x_size, text_y_size) = draw.textsize(str(o), font=fnt) + draw.text((x * grid_size + (grid_size - text_x_size) // 2, + y * grid_size + (grid_size - text_y_size) // 2), str(o), font=fnt, fill=white) + assert text_x_size < grid_size + assert text_y_size < grid_size + + return np.asarray(pil_board).swapaxes(0, 1) + + outfile = StringIO() if mode == 'ansi' else sys.stdout + s = 'Current Return: {}, '.format(self.episode_return) + s += 'Highest Tile: {}\n'.format(self.highest()) + npa = np.array(self.board) + grid = npa.reshape((self.size, self.size)) + s += "{}\n".format(grid) + outfile.write(s) + return outfile + + # Implementation of game logic for 2048 + def add_random_2_4_tile(self): + """Add a 
tile with value 2 or 4 with different probabilities.""" + possible_tiles = np.array([2, 4]) + tile_probabilities = np.array([0.9, 0.1]) + val = self.np_random.choice(possible_tiles, 1, p=tile_probabilities)[0] + empty_location = self.get_empty_location() + # assert empty_location.shape[0] + if empty_location.shape[0] == 0: + self.should_done = True + return + empty_idx = self.np_random.choice(empty_location.shape[0]) + empty = empty_location[empty_idx] + logging.debug("Adding %s at %s", val, (empty[0], empty[1])) + val_chance_cum = 0 + # if val == 4: + # val_chance_cum = 16 + self.chance = val_chance_cum + 4 * empty[0] + empty[1] + self.set(empty[0], empty[1], val) + + def get(self, x, y): + """Get the value of one square.""" + return self.board[x, y] + + def set(self, x, y, val): + """Set the value of one square.""" + self.board[x, y] = val + + def get_empty_location(self): + """Return a 2d numpy array with the location of empty squares.""" + return np.argwhere(self.board == 0) + + def highest(self): + """Report the highest tile on the board.""" + return np.max(self.board) + + def move(self, direction, trial=False): + """ + Overview: + Perform one move of the game. Shift things to one side then, + combine. directions 0, 1, 2, 3 are up, right, down, left. + Returns the reward that [would have] got. + Arguments: + - direction (:obj:`int`): The direction to move. + - trial (:obj:`bool`): Whether this is a trial move. + """ + if not trial: + if direction == 0: + logging.debug("Up") + elif direction == 1: + logging.debug("Right") + elif direction == 2: + logging.debug("Down") + elif direction == 3: + logging.debug("Left") + + changed = False + move_reward = 0 + dir_div_two = int(direction / 2) + dir_mod_two = int(direction % 2) + # 0 for towards up or left, 1 for towards bottom or right + shift_direction = dir_mod_two ^ dir_div_two + + # Construct a range for extracting row/column into a list + rx = list(range(self.w)) + ry = list(range(self.h)) + + if dir_mod_two == 0: + # Up or down, split into columns + for y in range(self.h): + old = [self.get(x, y) for x in rx] + (new, ms) = self.shift(old, shift_direction) + move_reward += ms + if old != new: + changed = True + if not trial: + for x in rx: + self.set(x, y, new[x]) + else: + # Left or right, split into rows + for x in range(self.w): + old = [self.get(x, y) for y in ry] + (new, ms) = self.shift(old, shift_direction) + move_reward += ms + if old != new: + changed = True + if not trial: + for y in ry: + self.set(x, y, new[y]) + # if not changed: + # raise IllegalMove + + return move_reward + + @property + def legal_actions(self): + """ + Overview: + Return the legal actions for the current state. + Arguments: + - None + Returns: + - legal_actions (:obj:`list`): The legal actions. 
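+            Note: when ``ignore_legal_actions`` is True, all four directions are returned regardless of the board.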
+ """ + if self.ignore_legal_actions: + return [0,1,2,3] + legal_actions = [] + for direction in range(4): + changed = False + move_reward = 0 + dir_div_two = int(direction / 2) + dir_mod_two = int(direction % 2) + # 0 for towards up or left, 1 for towards bottom or right + shift_direction = dir_mod_two ^ dir_div_two + + # Construct a range for extracting row/column into a list + rx = list(range(self.w)) + ry = list(range(self.h)) + + if dir_mod_two == 0: + # Up or down, split into columns + for y in range(self.h): + old = [self.get(x, y) for x in rx] + (new, move_reward_tmp) = self.shift(old, shift_direction) + move_reward += move_reward_tmp + if old != new: + changed = True + else: + # Left or right, split into rows + for x in range(self.w): + old = [self.get(x, y) for y in ry] + (new, move_reward_tmp) = self.shift(old, shift_direction) + move_reward += move_reward_tmp + if old != new: + changed = True + + if changed: + legal_actions.append(direction) + + return legal_actions + + def combine(self, shifted_row): + """Combine same tiles when moving to one side. This function always + shifts towards the left. Also count the reward of combined tiles.""" + move_reward = 0 + combined_row = [0] * self.size + skip = False + output_index = 0 + for p in pairwise(shifted_row): + if skip: + skip = False + continue + combined_row[output_index] = p[0] + if p[0] == p[1]: + combined_row[output_index] += p[1] + move_reward += p[0] + p[1] + # Skip the next thing in the list. + skip = True + output_index += 1 + if shifted_row and not skip: + combined_row[output_index] = shifted_row[-1] + + return combined_row, move_reward + + def shift(self, row, direction): + """Shift one row left (direction == 0) or right (direction == 1), combining if required.""" + length = len(row) + assert length == self.size + # assert direction == 0 or direction == 1 + + # Shift all non-zero digits up + shifted_row = [i for i in row if i != 0] + + # Reverse list to handle shifting to the right + if direction: + shifted_row.reverse() + + (combined_row, move_reward) = self.combine(shifted_row) + + # Reverse list to handle shifting to the right + if direction: + combined_row.reverse() + + assert len(combined_row) == self.size + return combined_row, move_reward + + def is_end(self): + """Has the game ended. Game ends if there is a tile equal to the limit + or there are no legal moves. 
If there are empty spaces then there + must be legal moves.""" + + if self.max_tile is not None and self.highest() == self.max_tile: + return True + elif len(self.legal_actions) == 0: + # the agent don't have legal_actions to move, so the episode is done + return True + elif self.should_done: + return True + else: + return False + + def get_board(self): + """Get the whole board-matrix, useful for testing.""" + return self.board + + def set_board(self, new_board): + """Set the whole board-matrix, useful for testing.""" + self.board = new_board + + def random_action(self) -> np.ndarray: + random_action = self.action_space.sample() + if isinstance(random_action, np.ndarray): + pass + elif isinstance(random_action, int): + random_action = to_ndarray([random_action], dtype=np.int64) + return random_action + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_range + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_env_num = cfg.pop('collector_env_num') + cfg = copy.deepcopy(cfg) + # cfg.reward_normalize = True + # when collect data, sometimes we need to normalize the reward + # reward_normalize is determined by the config. + cfg.is_collect = True + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_env_num = cfg.pop('evaluator_env_num') + cfg = copy.deepcopy(cfg) + # when evaluate, we don't need to normalize the reward. + cfg.reward_normalize = False + cfg.is_collect = False + return [cfg for _ in range(evaluator_env_num)] + + def __repr__(self) -> str: + return "LightZero 2048 Env." + + +def pairwise(iterable): + """s -> (s0,s1), (s1,s2), (s2, s3), ...""" + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +class IllegalMove(Exception): + pass + +class IllegalActionError(Exception): + pass + +def encoding_board(flat, num_of_template_tiles=16): + """ + Overview: + Convert an [4, 4] raw board into [4, 4, num_of_template_tiles] one-hot encoding. + Arguments: + - board (:obj:`np.ndarray`): the raw board + - num_of_template_tiles (:obj:`int`): the number of template_tiles + Returns: + - one_hot_board (:obj:`np.ndarray`): the one-hot encoding board + """ + # TODO(pu): the more elegant one-hot encoding implementation + # template_tiles is what each layer represents + # template_tiles = 2 ** (np.arange(num_of_template_tiles, dtype=int) + 1) + template_tiles = 2 ** (np.arange(num_of_template_tiles, dtype=int)) + template_tiles[0] = 0 + # layered is the flat board repeated num_of_template_tiles times + layered = np.repeat(flat[:, :, np.newaxis], num_of_template_tiles, axis=-1) + + # Now set the values in the board to 1 or zero depending on whether they match template_tiles. 
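+    # For example, an empty cell activates channel 0, a tile of value 2 activates channel 1 and, in general,
+    # a tile of value 2**k activates channel k, so 16 channels cover tiles up to 2**15.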
+ # template_tiles is broadcast across a number of axes + one_hot_board = np.where(layered == template_tiles, 1, 0) + return one_hot_board \ No newline at end of file From 06c0558f5378e0c623b80ad0628d1f80dbbf7a6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= Date: Fri, 16 Jun 2023 12:13:01 +0800 Subject: [PATCH 04/28] made corrections to the comments and naming issues --- .../mcts/tree_search/mcts_ctree_stochastic.py | 2 +- lzero/model/stochastic_muzero_model.py | 4 +- lzero/policy/stochastic_muzero.py | 76 ++++++++++++------- 3 files changed, 51 insertions(+), 31 deletions(-) diff --git a/lzero/mcts/tree_search/mcts_ctree_stochastic.py b/lzero/mcts/tree_search/mcts_ctree_stochastic.py index d2e9a2300..71b3c80b0 100644 --- a/lzero/mcts/tree_search/mcts_ctree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ctree_stochastic.py @@ -17,7 +17,7 @@ class StochasticMuZeroMCTSCtree(object): """ Overview: - MCTSCtree for MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + MCTSCtree for Stochastic MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. Interfaces: __init__, roots, search diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 95bc1c5ba..345c0497f 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -54,6 +54,7 @@ def __init__( Arguments: - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari. - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - chance_space_size: (:obj:`int`): Chance space size, the action space for decision node, usually an integer number for discrete action space. - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - num_channels (:obj:`int`): The channels of hidden states. - reward_head_channels (:obj:`int`): The channels of reward head. @@ -846,7 +847,8 @@ def __init__(self, self.onehot_argmax = StraightThroughEstimator() def forward(self, o_i): #https://openreview.net/pdf?id=X6D9bAHhBQ1 [page:5 chance outcome] - c_e_t = torch.nn.Softmax(-1)(self.encoder(o_i)) + # c_e_t = torch.nn.Softmax(-1)(self.encoder(o_i)) + c_e_t = self.encoder(o_i) #c_t= torch.zeros_like(c_e_t).scatter_(-1, torch.argmax(c_e_t, dim=-1,keepdim=True), 1.) c_t = self.onehot_argmax(c_e_t) return c_t,c_e_t diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 9087c4cbf..712ae92f2 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -22,7 +22,7 @@ class StochasticMuZeroPolicy(Policy): """ Overview: - The policy class for MuZero. + The policy class for Stochastic MuZero. """ # The default_config for MuZero policy. @@ -121,6 +121,12 @@ class StochasticMuZeroPolicy(Policy): value_loss_weight=0.25, # (float) The weight of policy loss. policy_loss_weight=1, + # (float) The weight of afterstate policy loss. + afterstate_policy_loss_weight=1, + # (float) The weight of afterstate value loss. + afterstate_value_loss_weight=0.25, + # (float) The weight of vqvae encoder commitment loss. + commitment_loss_weight=1.0, # (float) The weight of ssl (self-supervised learning) loss. ssl_loss_weight=0, # (bool) Whether to use piecewise constant learning rate decay. 
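For readers skimming the patch, the sketch below shows how these new weights are meant to enter the training objective, mirroring the weighted sum added to ``_forward_learn`` later in this patch. It is an illustration only: the function name and the ``losses`` dict are hypothetical, each entry is assumed to be a per-sample loss tensor of shape ``(batch_size,)``, and ``weights`` are the prioritized-replay importance weights.

def weighted_stochastic_muzero_loss(cfg, losses, weights):
    # Weighted per-sample sum of the loss terms, following the config keys above.
    loss = (
        cfg.ssl_loss_weight * losses['consistency']
        + cfg.policy_loss_weight * losses['policy']
        + cfg.value_loss_weight * losses['value']
        + cfg.reward_loss_weight * losses['reward']
        + cfg.afterstate_policy_loss_weight * losses['afterstate_policy']
        + cfg.afterstate_value_loss_weight * losses['afterstate_value']
        + cfg.commitment_loss_weight * losses['commitment']
    )
    # Importance weights are applied per sample before averaging over the batch.
    return (weights * loss).mean()

The ssl/policy/value/reward weights are the usual MuZero terms; the three weights added in this hunk are the Stochastic MuZero specific ones.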
@@ -191,7 +197,12 @@ def _init_learn(self) -> None: self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) if self._cfg.lr_piecewise_constant_decay: from torch.optim.lr_scheduler import LambdaLR @@ -238,17 +249,15 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self._target_model.train() current_batch, target_batch = data - obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch target_reward, target_value, target_policy = target_batch - - - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + obs_batch, obs_target_batch = prepare_obs(obs_batch_orig, self._cfg) encoder_image_list = [] encoder_image_list.append(obs_batch) - for zt in range(5): - beg_index = self._cfg.model.image_channel * zt - end_index = self._cfg.model.image_channel * (zt + self._cfg.model.frame_stack_num) + for i in range(self._cfg.num_unroll_steps): + beg_index = self._cfg.model.image_channel * i + end_index = self._cfg.model.image_channel * (i + self._cfg.model.frame_stack_num) encoder_image_list.append(obs_target_batch[:, beg_index:end_index, :, :]) # do augmentations @@ -315,7 +324,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - + afterstate_policy_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) afterstate_value_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) commitment_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) @@ -326,20 +335,25 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # the core recurrent_inference in MuZero policy. # ============================================================== for step_i in range(self._cfg.num_unroll_steps): - # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, - # given current ``latent_state`` and ``action``. - # And then predict policy_logits and value with the prediction function. - network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_i], afterstate=False) - after_state, a_reward, a_value, a_policy_logits = mz_network_output_unpack(network_output) - + # unroll with the afterstate dynamic function: predict 'afterstate state', + # given current ``state`` and ``action``. + # 'afterstate reward' is not used, we kept it for the sake of uniformity between decision nodes and chance nodes. + # And then predict afterstate policy_logits and afterstate value with the afterstate prediction function. 
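+            # Schematically, each unroll step is split in two:
+            #   (latent_state, action)          --afterstate dynamics-->  after_state
+            #   (after_state, chance_code_long) --dynamics-->             next latent_state, reward
+            # with the chance code produced below by the chance encoder from two consecutive frames.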
+ network_output = self._learn_model.recurrent_inference( + latent_state, action_batch[:, step_i], afterstate=False + ) + after_state, afterstate_reward, afterstate_value, afterstate_policy_logits = mz_network_output_unpack(network_output) + + # concat consecutive frames to calculate ground truth chance former_frame = encoder_image_list[step_i] - latter_frame = encoder_image_list[step_i+1] + latter_frame = encoder_image_list[step_i + 1] concat_frame = torch.cat((former_frame, latter_frame), dim=1) - chance_code, encode_output = self._learn_model._encode_vqvae(concat_frame) - #chance_code, encode_output = self._learn_model._encode_vqvae(encoder_image_list[step_i]) chance_code_long = torch.argmax(chance_code, dim=1).long().unsqueeze(-1) + # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, + # given current ``after_state`` and ``chance_long``. + # And then predict policy_logits and value with the prediction function. network_output = self._learn_model.recurrent_inference(after_state, chance_code_long, afterstate=True) latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) @@ -380,10 +394,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # NOTE: the +=. # ============================================================== policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) - afterstate_policy_loss += cross_entropy_loss(a_policy_logits, chance_code) + afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_code) commitment_loss += cross_entropy_loss(encode_output, chance_code) - afterstate_value_loss += cross_entropy_loss(a_value, target_value_categorical[:, step_i]) + afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_i]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) @@ -407,10 +421,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # weighted loss with masks (some invalid states which are out of trajectory.) 
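# NOTE (illustrative restatement, not additional code): written out, the per-sample loss
# assembled below is
#     L = w_ssl * L_consistency + w_policy * L_policy + w_value * L_value + w_reward * L_reward
#         + w_afterstate_policy * L_afterstate_policy + w_afterstate_value * L_afterstate_value
#         + w_commitment * L_commitment,
# with the weights taken from the config entries introduced in this patch
# (afterstate_policy_loss_weight, afterstate_value_loss_weight, commitment_loss_weight, ...).
# The summed loss is then scaled by the per-sample priority weights ``weights`` and averaged
# into ``weighted_total_loss``.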
loss = ( self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + - self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss - + self._cfg.policy_loss_weight * afterstate_policy_loss + self._cfg.value_loss_weight * afterstate_value_loss - + self._cfg.policy_loss_weight * commitment_loss - + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + + self._cfg.afterstate_policy_loss_weight * afterstate_policy_loss + + self._cfg.afterstate_value_loss_weight * afterstate_value_loss + self._cfg.commitment_loss_weight * commitment_loss ) weighted_total_loss = (weights * loss).mean() @@ -432,10 +445,15 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # packing loss info for tensorboard logging loss_info = ( - weighted_total_loss.item(), loss.mean().item(), policy_loss.mean().item(), reward_loss.mean().item(), - value_loss.mean().item(), consistency_loss.mean(), afterstate_policy_loss.mean().item(), - afterstate_value_loss.mean().item(), commitment_loss.mean().item(), - + weighted_total_loss.item(), + loss.mean().item(), + policy_loss.mean().item(), + reward_loss.mean().item(), + value_loss.mean().item(), + consistency_loss.mean(), + afterstate_policy_loss.mean().item(), + afterstate_value_loss.mean().item(), + commitment_loss.mean().item(), ) if self._cfg.monitor_extra_statistics: predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1) From fa88aab0a2d982b6007489f4998b8b38ea86549a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= Date: Fri, 16 Jun 2023 12:16:58 +0800 Subject: [PATCH 05/28] made corrections to the comments and naming issues --- lzero/model/stochastic_muzero_model.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 345c0497f..0162687b3 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -83,11 +83,6 @@ def __init__( we don't need this module. """ super(StochasticMuZeroModel, self).__init__() - if isinstance(observation_shape, int) or len(observation_shape) == 1: - # for vector obs input, e.g. 
classical control ad box2d environments - # to be compatible with LightZero model/policy, transform to shape: [C, W, H] - observation_shape = [1, observation_shape, 1] - self.categorical_distribution = categorical_distribution if self.categorical_distribution: self.reward_support_size = reward_support_size From b7a3fbafbeab43128b5a80a4ff5b20579afb472e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= Date: Mon, 10 Jul 2023 22:15:24 +0800 Subject: [PATCH 06/28] ok --- .../buffer/game_buffer_stochastic_muzero.py | 11 ++- lzero/mcts/buffer/game_segment.py | 11 ++- lzero/model/stochastic_muzero_model.py | 4 +- lzero/policy/stochastic_muzero.py | 12 ++- lzero/worker/muzero_collector.py | 11 ++- zoo/game_2048/config/explicit_2048_config.py | 94 +++++++++++++++++++ zoo/game_2048/envs/game_2048_env.py | 5 +- 7 files changed, 136 insertions(+), 12 deletions(-) create mode 100644 zoo/game_2048/config/explicit_2048_config.py diff --git a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py index 28c978a3a..3f93e695b 100644 --- a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py +++ b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py @@ -111,6 +111,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data batch_size = len(batch_index_list) obs_list, action_list, mask_list = [], [], [] + chance_list = [] # prepare the inputs of a batch for i in range(batch_size): game = game_segment_list[i] @@ -118,6 +119,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() + chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() # add mask for invalid actions (out of trajectory) mask_tmp = [1. for i in range(len(actions_tmp))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))] @@ -127,7 +130,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: np.random.randint(0, game.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] - + chances_tmp += [ + np.random.randint(0, game.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(chances_tmp)) + ] # obtain the input observations # pad if length of obs in game_segment is less than stack+num_unroll_steps # e.g. 
stack+num_unroll_steps 4+5 @@ -138,12 +144,13 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: ) action_list.append(actions_tmp) mask_list.append(mask_tmp) + chance_list.append(chances_tmp) # formalize the input observations obs_list = prepare_observation(obs_list, self._cfg.model.model_type) # formalize the inputs of a batch - current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list, chance_list] for i in range(len(current_batch)): current_batch[i] = np.asarray(current_batch[i]) diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index be95c8448..b0d9cfdb0 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -64,6 +64,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.action_mask_segment = [] self.to_play_segment = [] + self.chance_segment = [] self.target_values = [] self.target_rewards = [] @@ -128,7 +129,8 @@ def append( obs: np.ndarray, reward: np.ndarray, action_mask: np.ndarray = None, - to_play: int = -1 + to_play: int = -1, + chance: np.ndarray=0, ) -> None: """ Overview: @@ -140,10 +142,11 @@ def append( self.action_mask_segment.append(action_mask) self.to_play_segment.append(to_play) + self.chance_segment.append(chance) def pad_over( self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List, - next_segment_child_visits: List, next_segment_improved_policy: List = None + next_segment_child_visits: List, next_chances: List = None,next_segment_improved_policy: List = None, ) -> None: """ Overview: @@ -184,6 +187,8 @@ def pad_over( if self.config.gumbel_algo: for improved_policy in next_segment_improved_policy: self.improved_policy_probs.append(improved_policy) + for chances in next_chances: + self.chance_segment.append(chances) def get_targets(self, timestep: int) -> Tuple: """ @@ -253,6 +258,7 @@ def game_segment_to_array(self) -> None: self.action_mask_segment = np.array(self.action_mask_segment) self.to_play_segment = np.array(self.to_play_segment) + self.chance_segment = np.array(self.chance_segment) def reset(self, init_observations: np.ndarray) -> None: """ @@ -271,6 +277,7 @@ def reset(self, init_observations: np.ndarray) -> None: self.action_mask_segment = [] self.to_play_segment = [] + self.chance_segment = [] assert len(init_observations) == self.frame_stack_num diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 0162687b3..7a008deb8 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -842,8 +842,8 @@ def __init__(self, self.onehot_argmax = StraightThroughEstimator() def forward(self, o_i): #https://openreview.net/pdf?id=X6D9bAHhBQ1 [page:5 chance outcome] - # c_e_t = torch.nn.Softmax(-1)(self.encoder(o_i)) - c_e_t = self.encoder(o_i) + c_e_t = torch.nn.Softmax(-1)(self.encoder(o_i)) + #c_e_t = self.encoder(o_i) #c_t= torch.zeros_like(c_e_t).scatter_(-1, torch.argmax(c_e_t, dim=-1,keepdim=True), 1.) 
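# NOTE (clarifying comment, sketching the intent of the code below): ``onehot_argmax`` turns the
# soft encoder output c_e_t into a hard one-hot chance code c_t. Because it is implemented as a
# straight-through estimator (see Onehot_argmax further down in this file), the forward pass uses
# the argmax one-hot while the backward pass routes gradients through c_e_t into the encoder.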
c_t = self.onehot_argmax(c_e_t) return c_t,c_e_t diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 712ae92f2..a071f30be 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -159,6 +159,7 @@ class StochasticMuZeroPolicy(Policy): root_dirichlet_alpha=0.3, # (float) The noise weight at the root node of the search tree. root_noise_weight=0.25, + explicit_chance_label = False, ) def default_model(self) -> Tuple[str, List[str]]: @@ -249,8 +250,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self._target_model.train() current_batch, target_batch = data - obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch + obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch target_reward, target_value, target_policy = target_batch + + if self._cfg.explicit_chance_label: + chance_batch = torch.LongTensor(chance_batch).to(self._cfg.device) + chance_batch =torch.nn.functional.one_hot(chance_batch, self._cfg.model.chance_space_size) obs_batch, obs_target_batch = prepare_obs(obs_batch_orig, self._cfg) encoder_image_list = [] @@ -349,6 +354,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latter_frame = encoder_image_list[step_i + 1] concat_frame = torch.cat((former_frame, latter_frame), dim=1) chance_code, encode_output = self._learn_model._encode_vqvae(concat_frame) + if self._cfg.explicit_chance_label: + chance_code = chance_batch[:, step_i] chance_code_long = torch.argmax(chance_code, dim=1).long().unsqueeze(-1) # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, @@ -395,7 +402,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_code) - commitment_loss += cross_entropy_loss(encode_output, chance_code) + # commitment_loss += cross_entropy_loss(encode_output, chance_code) + commitment_loss += torch.nn.MSELoss()(encode_output, chance_code) * 0.01 afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_i]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index fffe6e84e..dd56a5846 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -232,6 +232,7 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti end_index = beg_index + self.unroll_plus_td_steps - 1 pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + chance_lst = game_segments[i].chance_segment[beg_index:end_index] beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps @@ -245,7 +246,8 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti if self.policy_config.gumbel_algo: last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, next_segment_improved_policy = pad_improved_policy_prob) else: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + #last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + 
last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, chance_lst) """ Note: game_segment element shape: @@ -314,6 +316,7 @@ def collect(self, action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} + chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} game_segments = [ GameSegment( @@ -367,8 +370,10 @@ def collect(self, action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} + chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] + chance = [chance_dict[env_id] for env_id in ready_env_id] stack_obs = to_ndarray(stack_obs) @@ -455,13 +460,14 @@ def collect(self, # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` game_segments[env_id].append( actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id] + to_play_dict[env_id], chance_dict[env_id] ) # NOTE: the position of code snippet is very important. # the obs['action_mask'] and obs['to_play'] is corresponding to next action action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) + chance_dict[env_id] = to_ndarray(obs['chance']) dones[env_id] = done visit_entropies_lst[env_id] += visit_entropy_dict[env_id] @@ -581,6 +587,7 @@ def collect(self, action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) + chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) game_segments[env_id] = GameSegment( self._env.action_space, diff --git a/zoo/game_2048/config/explicit_2048_config.py b/zoo/game_2048/config/explicit_2048_config.py new file mode 100644 index 000000000..dcd743cdb --- /dev/null +++ b/zoo/game_2048/config/explicit_2048_config.py @@ -0,0 +1,94 @@ +from easydict import EasyDict + +env_name = 'game_2048' +action_space_size = 4 +chance_space_size = 32 +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 1 +n_episode = 1 +evaluator_env_num = 1 +num_simulations = 5 # TODO(pu):100 +update_per_collect = 3 +batch_size = 7 +max_env_step = int(1e6) +reanalyze_ratio = 0. 
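# Illustration only (hypothetical helper, not part of the configs or the env): with
# chance_space_size = 32 the chance index produced by game_2048_env.add_random_2_4_tile()
# encodes both the value and the board position of the spawned tile,
#     chance = (16 if tile_val == 4 else 0) + 4 * row + col,
# so it can be decoded as follows:
def _decode_chance(chance: int):
    tile_val = 4 if chance >= 16 else 2
    pos = chance % 16
    return tile_val, pos // 4, pos % 4  # (spawned tile value, row, column)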
+# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +game_2048_stochastic_muzero_config = dict( + exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-true_seed0', + env=dict( + stop_value=int(1e6), + env_name=env_name, + obs_shape=(16, 4, 4), + obs_type='dict_observation', + reward_normalize=True, + reward_scale=100, + max_tile=int(2**16), # 2**11=2048, 2**16=65536 + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(16, 4, 4), + action_space_size=action_space_size, + chance_space_size=chance_space_size, + image_channel=16, + # NOTE: whether to use the self_supervised_learning_loss. default is False + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + mcts_ctree=True, + explicit_chance_label=True, + gumbel_algo=False, + cuda=True, + env_type='not_board_games', + game_segment_length=400, + update_per_collect=update_per_collect, + batch_size=batch_size, + td_steps=10, + discount_factor=0.999, + manual_temperature_decay=True, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, # init lr for manually decay schedule + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +game_2048_stochastic_muzero_config = EasyDict(game_2048_stochastic_muzero_config) +main_config = game_2048_stochastic_muzero_config + +game_2048_stochastic_muzero_create_config = dict( + env=dict( + type='game_2048', + import_names=['zoo.game_2048.envs.game_2048_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='stochastic_muzero', + import_names=['lzero.policy.stochastic_muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +game_2048_stochastic_muzero_create_config = EasyDict(game_2048_stochastic_muzero_create_config) +create_config = game_2048_stochastic_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 4379a6c31..4bb4437c1 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -195,6 +195,7 @@ def step(self, action): action_mask[self.legal_actions] = 1 if self.obs_type == 'dict_observation': + observation[0,0,0] = self.chance observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} elif self.obs_type == 'array': observation = self.board @@ -285,8 +286,8 @@ def add_random_2_4_tile(self): empty = empty_location[empty_idx] logging.debug("Adding %s at %s", val, (empty[0], empty[1])) val_chance_cum = 0 - # if val == 4: - # val_chance_cum = 16 + if val == 4: + val_chance_cum = 16 self.chance = val_chance_cum + 4 * empty[0] + empty[1] self.set(empty[0], empty[1], val) From 9168a1bb5d288b66037ea5505dc768b1af98a001 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= Date: Mon, 10 Jul 2023 22:45:36 +0800 Subject: [PATCH 07/28] ok --- .../config/train_explicit_2048_config.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 zoo/game_2048/config/train_explicit_2048_config.py diff --git a/zoo/game_2048/config/train_explicit_2048_config.py b/zoo/game_2048/config/train_explicit_2048_config.py new file mode 100644 index 000000000..c2fe37df5 --- /dev/null +++ b/zoo/game_2048/config/train_explicit_2048_config.py @@ -0,0 +1,105 @@ +from easydict import EasyDict + +env_name = 'game_2048' +action_space_size = 4 +chance_space_size= 16 +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 1 +n_episode = 1 +evaluator_env_num = 3 +num_simulations = 50 # TODO(pu):100 +update_per_collect = 100 +batch_size = 1024 +max_env_step = int(1e8) +reanalyze_ratio = 0. + +# collector_env_num = 1 +# n_episode = 2 +# evaluator_env_num = 1 +# num_simulations = 5 +# update_per_collect = 3 +# batch_size = 5 +# max_env_step = int(1e6) +# reanalyze_ratio = 0. 
+# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +game_2048_stochastic_muzero_config = dict( + exp_name=f'june05_data_stomz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-false_seed0', + env=dict( + stop_value=int(1e6), + env_name=env_name, + obs_shape=(16, 4, 4), + obs_type='dict_observation', + reward_normalize=False, + reward_scale=100, + max_tile=int(2**16), # 2**11=2048, 2**16=65536 + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(16, 4, 4), + action_space_size=action_space_size, + chance_space_size=chance_space_size, + image_channel=16, + # NOTE: whether to use the self_supervised_learning_loss. default is False + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + mcts_ctree=True, + gumbel_algo=False, + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + td_steps=10, + discount_factor=0.999, + manual_temperature_decay=True, + # optim_type='SGD', + # lr_piecewise_constant_decay=True, + # learning_rate=0.2, # init lr for manually decay schedule + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=0, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(2e7), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +game_2048_stochastic_muzero_config = EasyDict(game_2048_stochastic_muzero_config) +main_config = game_2048_stochastic_muzero_config + +game_2048_stochastic_muzero_create_config = dict( + env=dict( + type='game_2048', + import_names=['zoo.game_2048.envs.game_2048_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='stochastic_muzero', + import_names=['lzero.policy.stochastic_muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +game_2048_stochastic_muzero_create_config = EasyDict(game_2048_stochastic_muzero_create_config) +create_config = game_2048_stochastic_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) From 648f7473798695f9e26ffb6490f7270a2365d391 Mon Sep 17 00:00:00 2001 From: timothijoe Date: Mon, 10 Jul 2023 23:57:43 +0800 Subject: [PATCH 08/28] ok --- zoo/game_2048/config/train_explicit_2048_config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/zoo/game_2048/config/train_explicit_2048_config.py b/zoo/game_2048/config/train_explicit_2048_config.py index c2fe37df5..ec44696c1 100644 --- a/zoo/game_2048/config/train_explicit_2048_config.py +++ b/zoo/game_2048/config/train_explicit_2048_config.py @@ -6,12 +6,12 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -collector_env_num = 1 -n_episode = 1 +collector_env_num = 8 +n_episode = 8 evaluator_env_num = 3 num_simulations = 50 # TODO(pu):100 -update_per_collect = 100 -batch_size = 1024 +update_per_collect = 200 +batch_size = 512 max_env_step = int(1e8) reanalyze_ratio = 0. @@ -28,7 +28,7 @@ # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'june05_data_stomz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-false_seed0', + exp_name=f'july10_data_stomz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-false_seed0', env=dict( stop_value=int(1e6), env_name=env_name, @@ -71,7 +71,7 @@ learning_rate=0.003, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=0, # default is 0 + ssl_loss_weight=0.1, # default is 0 n_episode=n_episode, eval_freq=int(2e3), replay_buffer_size=int(2e7), # the size/capacity of replay_buffer, in the terms of transitions. 
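The patches above wire a per-step "chance" label from the 2048 environment through the collector,
the game segments and the replay buffer into the policy. A minimal sketch of how those labels
become training targets when ``explicit_chance_label`` is enabled (the helper below is
hypothetical; only the one-hot call mirrors the patch):

    import torch
    import torch.nn.functional as F

    def chance_targets(chance_batch, chance_space_size, device='cpu'):
        """One-hot chance labels, shape [batch, num_unroll_steps, chance_space_size] (sketch)."""
        chance = torch.as_tensor(chance_batch, dtype=torch.long, device=device)
        return F.one_hot(chance, chance_space_size)

    # At unroll step ``step_i`` the ground-truth code replaces the learned encoder output as the
    # target of afterstate_policy_loss, roughly:
    #     chance_code = chance_targets(chance_batch, chance_space_size)[:, step_i]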
From 4693272844da72e3741f773f76f23992cdf701f7 Mon Sep 17 00:00:00 2001 From: timothijoe Date: Tue, 11 Jul 2023 10:28:17 +0800 Subject: [PATCH 09/28] ok --- zoo/game_2048/envs/game_2048_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 4bb4437c1..2e52a0e43 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -195,7 +195,6 @@ def step(self, action): action_mask[self.legal_actions] = 1 if self.obs_type == 'dict_observation': - observation[0,0,0] = self.chance observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} elif self.obs_type == 'array': observation = self.board From 11b4b7b5b7b76c02eace385596b85a65f473a5e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Wed, 2 Aug 2023 19:05:08 +0800 Subject: [PATCH 10/28] polish(pu): polish game_2048_env --- .../mcts/tree_search/mcts_ctree_stochastic.py | 45 +-- lzero/model/stochastic_muzero_model.py | 165 ++++----- lzero/policy/stochastic_muzero.py | 28 +- zoo/game_2048/config/muzero_2048_config.py | 36 +- .../config/rule_based_2048_config.py | 22 +- .../config/stochastic_muzero_2048_config.py | 49 ++- zoo/game_2048/envs/game_2048_env.py | 327 +++++++++--------- 7 files changed, 356 insertions(+), 316 deletions(-) diff --git a/lzero/mcts/tree_search/mcts_ctree_stochastic.py b/lzero/mcts/tree_search/mcts_ctree_stochastic.py index 71b3c80b0..f248d8586 100644 --- a/lzero/mcts/tree_search/mcts_ctree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ctree_stochastic.py @@ -57,7 +57,8 @@ def __init__(self, cfg: EasyDict = None) -> None: ) @classmethod - def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any], chance_space_size: int=2) -> "stochastic_mz_tree.Roots": + def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any], + chance_space_size: int = 2) -> "stochastic_mz_tree.Roots": """ Overview: The initialization of CRoots with root num and legal action lists. @@ -70,7 +71,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any], chanc def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]] ) -> None: """ Overview: @@ -185,7 +186,7 @@ def process_nodes(nodes_index, is_chance): process_nodes(chance_nodes_index, True) process_nodes(decision_nodes_index, False) - chance_nodes = chance_nodes_index + chance_nodes = chance_nodes_index decision_nodes = decision_nodes_index value_batch_chance = [value_batch[leaf_idx] for leaf_idx in chance_nodes] @@ -193,45 +194,33 @@ def process_nodes(nodes_index, is_chance): reward_batch_chance = [reward_batch[leaf_idx] for leaf_idx in chance_nodes] reward_batch_decision = [reward_batch[leaf_idx] for leaf_idx in decision_nodes] policy_logits_batch_chance = [policy_logits_batch[leaf_idx] for leaf_idx in chance_nodes] - policy_logits_batch_decision = [policy_logits_batch[leaf_idx] for leaf_idx in decision_nodes] + policy_logits_batch_decision = [policy_logits_batch[leaf_idx] for leaf_idx in decision_nodes] latent_state_batch = np.concatenate(latent_state_batch, axis=0) latent_state_batch_in_search_path.append(latent_state_batch) + + # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # statistics. 
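# NOTE (illustrative sketch, not the C++ implementation): conceptually, each backpropagate pass
# expands the leaf with the predicted policy/reward and then walks its search path bottom-up,
# accumulating the discounted return and updating the visit statistics, roughly:
#
#     bootstrap_value = leaf_value
#     for node in reversed(search_path):
#         node.value_sum += bootstrap_value
#         node.visit_count += 1
#         min_max_stats.update(node.reward + discount_factor * node.value())
#         bootstrap_value = node.reward + discount_factor * bootstrap_value
#
# The batched C++ version does this for every root at once; the stochastic variant additionally
# records, via ``child_is_chance_batch``, whether each expanded leaf is a chance or decision node.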
+ + # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. current_latent_state_index = simulation_index + 1 - if(len(chance_nodes) > 0): + if (len(chance_nodes) > 0): value_batch_chance = np.concatenate(value_batch_chance, axis=0).reshape(-1).tolist() reward_batch_chance = np.concatenate(reward_batch_chance, axis=0).reshape(-1).tolist() policy_logits_batch_chance = np.concatenate(policy_logits_batch_chance, axis=0).tolist() stochastic_mz_tree.batch_backpropagate( - current_latent_state_index, discount_factor, reward_batch_chance, value_batch_chance, policy_logits_batch_chance, + current_latent_state_index, discount_factor, reward_batch_chance, value_batch_chance, + policy_logits_batch_chance, min_max_stats_lst, results, virtual_to_play_batch, child_is_chance_batch, chance_nodes ) - if(len(decision_nodes)>0): + if (len(decision_nodes) > 0): value_batch_decision = np.concatenate(value_batch_decision, axis=0).reshape(-1).tolist() reward_batch_decision = np.concatenate(reward_batch_decision, axis=0).reshape(-1).tolist() policy_logits_batch_decision = np.concatenate(policy_logits_batch_decision, axis=0).tolist() stochastic_mz_tree.batch_backpropagate( - current_latent_state_index, discount_factor, reward_batch_decision, value_batch_decision, policy_logits_batch_decision, + current_latent_state_index, discount_factor, reward_batch_decision, value_batch_decision, + policy_logits_batch_decision, min_max_stats_lst, results, virtual_to_play_batch, child_is_chance_batch, decision_nodes ) - - - - - # latent_state_batch = np.concatenate(latent_state_batch, axis=0) - # value_batch = np.concatenate(value_batch, axis=0).reshape(-1).tolist() - # reward_batch = np.concatenate(reward_batch, axis=0).reshape(-1).tolist() - # policy_logits_batch = np.concatenate(policy_logits_batch, axis=0).tolist() - # latent_state_batch_in_search_path.append(latent_state_batch) - - # # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and - # # ``reward`` predicted by the model, then perform backpropagation along the search path to update the - # # statistics. - - # # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. 
- # current_latent_state_index = simulation_index + 1 - # stochastic_mz_tree.batch_backpropagate( - # current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, - # min_max_stats_lst, results, virtual_to_play_batch, child_is_chance_batch, leaf_idx_list - # ) diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 0162687b3..0083a6ae6 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -19,32 +19,32 @@ class StochasticMuZeroModel(nn.Module): def __init__( - self, - observation_shape: SequenceType = (12, 96, 96), - action_space_size: int = 6, - chance_space_size: int = 2, - num_res_blocks: int = 1, - num_channels: int = 64, - reward_head_channels: int = 16, - value_head_channels: int = 16, - policy_head_channels: int = 16, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, - proj_hid: int = 1024, - proj_out: int = 1024, - pred_hid: int = 512, - pred_out: int = 1024, - self_supervised_learning_loss: bool = False, - categorical_distribution: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), - last_linear_layer_init_zero: bool = True, - state_norm: bool = False, - downsample: bool = False, - *args, - **kwargs + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + chance_space_size: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + *args, + **kwargs ): """ Overview: @@ -90,7 +90,7 @@ def __init__( else: self.reward_support_size = 1 self.value_support_size = 1 - + self.action_space_size = action_space_size self.chance_space_size = chance_space_size @@ -125,7 +125,7 @@ def __init__( num_channels, downsample, ) - + self.encoder = Encoder_function( observation_shape, chance_space_size ) @@ -152,7 +152,7 @@ def __init__( flatten_output_size_for_policy_head, last_linear_layer_init_zero=self.last_linear_layer_init_zero, ) - + self.afterstate_dynamics_network = AfterstateDynamicsNetwork( num_res_blocks, num_channels + 1, @@ -232,7 +232,8 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: latent_state, ) - def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor, afterstate: bool = False) -> MZNetworkOutput: + def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor, + afterstate: bool = False) -> MZNetworkOutput: """ Overview: Recurrent inference of MuZero model, which is the rollout step of the MuZero model. 
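The two modes of ``recurrent_inference`` correspond to the two halves of one Stochastic MuZero
unroll step. A condensed sketch of how the policy's ``_forward_learn`` calls them (``model``,
``latent_state``, ``action`` and ``chance_code_long`` stand in for the variables used there, and
the unpacking uses ``mz_network_output_unpack`` as in the policy):

    # decision step: ``option`` is the agent's action, the output latent is the afterstate
    network_output = model.recurrent_inference(latent_state, action, afterstate=False)
    after_state, _, afterstate_value, afterstate_policy_logits = mz_network_output_unpack(network_output)

    # chance step: ``option`` is the index of the (one-hot) chance outcome, the output is the next latent state
    network_output = model.recurrent_inference(after_state, chance_code_long, afterstate=True)
    latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output)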
@@ -290,7 +291,7 @@ def _representation(self, observation: torch.Tensor) -> torch.Tensor: if self.state_norm: latent_state = renormalize(latent_state) return latent_state - + def _encode_vqvae(self, observation: torch.Tensor): output = self.encoder(observation) return output @@ -381,8 +382,9 @@ def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[t if self.state_norm: next_latent_state = renormalize(next_latent_state) return next_latent_state, reward - - def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor]: """ Overview: Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` @@ -433,9 +435,7 @@ def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) next_latent_state, reward = self.afterstate_dynamics_network(state_action_encoding) if self.state_norm: next_latent_state = renormalize(next_latent_state) - return next_latent_state, reward - - + return next_latent_state, reward def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: """ @@ -483,15 +483,15 @@ def get_params_mean(self) -> float: class DynamicsNetwork(nn.Module): def __init__( - self, - num_res_blocks: int, - num_channels: int, - reward_head_channels: int, - fc_reward_layers: SequenceType, - output_support_size: int, - flatten_output_size_for_reward_head: int, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), + self, + num_res_blocks: int, + num_channels: int, + reward_head_channels: int, + fc_reward_layers: SequenceType, + output_support_size: int, + flatten_output_size_for_reward_head: int, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), ): """ Overview: @@ -580,18 +580,19 @@ def get_dynamic_mean(self) -> float: def get_reward_mean(self) -> float: return get_reward_mean(self) + class AfterstateDynamicsNetwork(nn.Module): def __init__( - self, - num_res_blocks: int, - num_channels: int, - reward_head_channels: int, - fc_reward_layers: SequenceType, - output_support_size: int, - flatten_output_size_for_reward_head: int, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), + self, + num_res_blocks: int, + num_channels: int, + reward_head_channels: int, + fc_reward_layers: SequenceType, + output_support_size: int, + flatten_output_size_for_reward_head: int, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), ): """ Overview: @@ -671,14 +672,14 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to # use the fully connected layer to predict reward reward = self.fc_reward_head(x) - return afterstate_latent_state, reward + return afterstate_latent_state, reward def get_dynamic_mean(self) -> float: return get_dynamic_mean(self) def get_reward_mean(self) -> float: return get_reward_mean(self) - + class AfterstatePredictionNetwork(nn.Module): def __init__( @@ -757,7 +758,6 @@ def __init__( last_linear_layer_init_zero=last_linear_layer_init_zero ) - def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: @@ -785,11 +785,12 @@ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso value = self.fc_value(value) policy = 
self.fc_policy(policy) return policy, value - + + class ImgNet(nn.Module): def __init__(self, observation_space_dimensions, table_vec_dim=4): super(ImgNet, self).__init__() - self.conv1 = nn.Conv2d(observation_space_dimensions[0]*2, 32, 3, padding=1) + self.conv1 = nn.Conv2d(observation_space_dimensions[0] * 2, 32, 3, padding=1) self.conv2 = nn.Conv2d(32, 64, 3, padding=1) self.fc1 = nn.Linear(64 * observation_space_dimensions[1] * observation_space_dimensions[2], 128) self.fc2 = nn.Linear(128, 64) @@ -813,41 +814,15 @@ def __init__(self, super().__init__() self.action_space = action_dimension self.encoder = ImgNet(observation_space_dimensions, action_dimension) - # # # # add to sequence|first and recursive|,, whatever you need - # linear_in = nn.Linear(observation_space_dimensions, hidden_layer_dimensions) - # linear_mid = nn.Linear(hidden_layer_dimensions, hidden_layer_dimensions) - # linear_out = nn.Linear(hidden_layer_dimensions, state_dimension) - - # self.scale = nn.Tanh() - # layernom_init = nn.BatchNorm1d(observation_space_dimensions) - # layernorm_recur = nn.BatchNorm1d(hidden_layer_dimensions) - # # 0.1, 0.2 , 0.25 , 0.5 parameter (first two more recommended for rl) - # dropout = nn.Dropout(0.1) - # activation = nn.ELU() # , nn.ELU() , nn.GELU, nn.ELU() , nn.ELU - - # first_layer_sequence = [ - # linear_in, - # activation - # ] - - # recursive_layer_sequence = [ - # linear_mid, - # activation - # ] - - # sequence = first_layer_sequence + \ - # (recursive_layer_sequence*number_of_hidden_layer) - - # self.encoder = nn.Sequential(*tuple(sequence+[nn.Linear(hidden_layer_dimensions, action_dimension)])) self.onehot_argmax = StraightThroughEstimator() + def forward(self, o_i): - #https://openreview.net/pdf?id=X6D9bAHhBQ1 [page:5 chance outcome] + # https://openreview.net/pdf?id=X6D9bAHhBQ1 [page:5 chance outcome] # c_e_t = torch.nn.Softmax(-1)(self.encoder(o_i)) c_e_t = self.encoder(o_i) - #c_t= torch.zeros_like(c_e_t).scatter_(-1, torch.argmax(c_e_t, dim=-1,keepdim=True), 1.) + # c_t= torch.zeros_like(c_e_t).scatter_(-1, torch.argmax(c_e_t, dim=-1,keepdim=True), 1.) c_t = self.onehot_argmax(c_e_t) - return c_t,c_e_t - + return c_t, c_e_t class StraightThroughEstimator(nn.Module): @@ -857,14 +832,16 @@ def __init__(self): def forward(self, x): x = Onehot_argmax.apply(x) return x -#straight-through estimator is used during the backward to allow the gradients to flow only to the encoder during the backpropagation. + + +# straight-through estimator is used during the backward to allow the gradients to flow only to the encoder during the backpropagation. class Onehot_argmax(torch.autograd.Function): - #more information at : https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html + # more information at : https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html @staticmethod def forward(ctx, input): - #since the codebook is constant ,we can just use a transformation. no need to create a codebook and matmul c_e_t and codebook for argmax - return torch.zeros_like(input).scatter_(-1, torch.argmax(input, dim=-1,keepdim=True), 1.) + # since the codebook is constant ,we can just use a transformation. no need to create a codebook and matmul c_e_t and codebook for argmax + return torch.zeros_like(input).scatter_(-1, torch.argmax(input, dim=-1, keepdim=True), 1.) 
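# NOTE (illustrative, a commonly used equivalent rather than this repo's code): the same
# straight-through behaviour can be written without a custom autograd.Function by detaching the
# difference between the hard and soft codes, e.g.
#
#     soft = encoder(o_i)                                                        # c_e_t
#     hard = torch.zeros_like(soft).scatter_(-1, soft.argmax(-1, keepdim=True), 1.0)
#     chance_code = soft + (hard - soft).detach()   # forward: one-hot, backward: gradient of soft
#
# Both formulations pass the argmax one-hot forward while letting gradients flow to the encoder.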
@staticmethod def backward(ctx, grad_output): - return grad_output \ No newline at end of file + return grad_output diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 712ae92f2..13dc8fdd9 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -85,12 +85,18 @@ class StochasticMuZeroPolicy(Policy): augmentation=['shift', 'intensity'], # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, # (int) How many updates(iterations) to train after collector's one collection. # Bigger "update_per_collect" means bigger off-policy. # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor update_per_collect=100, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + model_update_ratio=0.1, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] @@ -159,6 +165,24 @@ class StochasticMuZeroPolicy(Policy): root_dirichlet_alpha=0.3, # (float) The noise weight at the root node of the search tree. root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # 'linear', 'exp' + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), ) def default_model(self) -> Tuple[str, List[str]]: @@ -335,7 +359,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # the core recurrent_inference in MuZero policy. # ============================================================== for step_i in range(self._cfg.num_unroll_steps): - # unroll with the afterstate dynamic function: predict 'afterstate state', + # unroll with the afterstate dynamic function: predict 'after_state', # given current ``state`` and ``action``. # 'afterstate reward' is not used, we kept it for the sake of uniformity between decision nodes and chance nodes. # And then predict afterstate policy_logits and afterstate value with the afterstate prediction function. 
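# NOTE (condensed restatement of the chance-encoding steps shown in the following hunk, for
# orientation): the ground-truth chance code for this step is obtained by encoding two consecutive
# observations,
#     concat_frame = torch.cat((encoder_image_list[step_i], encoder_image_list[step_i + 1]), dim=1)
#     chance_code, encode_output = self._learn_model._encode_vqvae(concat_frame)
# The hard one-hot ``chance_code`` is the target of afterstate_policy_loss and, via argmax, the
# ``option`` fed to the chance-step dynamics, while the soft ``encode_output`` is trained (via the
# commitment loss) to stay close to it. When ``explicit_chance_label`` is set, the env-provided
# one-hot label replaces ``chance_code``.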
@@ -348,6 +372,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in former_frame = encoder_image_list[step_i] latter_frame = encoder_image_list[step_i + 1] concat_frame = torch.cat((former_frame, latter_frame), dim=1) + chance_code, encode_output = self._learn_model._encode_vqvae(concat_frame) chance_code_long = torch.argmax(chance_code, dim=1).long().unsqueeze(-1) @@ -519,6 +544,7 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], + epsilon: float = 0.25, ready_env_id=None ) -> Dict: """ diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 52bbcd5c7..50afd7976 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -1,10 +1,21 @@ from easydict import EasyDict -env_name = 'game_2048' -action_space_size = 4 + # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== +env_name = 'game_2048' +action_space_size = 4 + +# collector_env_num = 8 +# n_episode = 8 +# evaluator_env_num = 3 +# num_simulations = 50 +# update_per_collect = 200 +# batch_size = 512 +# max_env_step = int(1e8) +# reanalyze_ratio = 0. + collector_env_num = 1 n_episode = 1 evaluator_env_num = 1 @@ -18,15 +29,16 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-true_seed0', + exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_seed0', env=dict( stop_value=int(1e6), env_name=env_name, obs_shape=(16, 4, 4), obs_type='dict_observation', - reward_normalize=True, - reward_scale=100, - max_tile=int(2**16), # 2**11=2048, 2**16=65536 + raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' + reward_normalize=False, + reward_norm_scale=100, + max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -50,9 +62,15 @@ td_steps=10, discount_factor=0.999, manual_temperature_decay=True, - optim_type='SGD', - lr_piecewise_constant_decay=True, - learning_rate=0.2, # init lr for manually decay schedule + + # optim_type='SGD', + # lr_piecewise_constant_decay=True, + # learning_rate=0.2, # init lr for manually decay schedule + + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=3e-3, + num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=2, # default is 0 diff --git a/zoo/game_2048/config/rule_based_2048_config.py b/zoo/game_2048/config/rule_based_2048_config.py index 191411ead..17c24d79e 100644 --- a/zoo/game_2048/config/rule_based_2048_config.py +++ b/zoo/game_2048/config/rule_based_2048_config.py @@ -11,7 +11,6 @@ def rule_based_search(grid: np.array, fast_search: bool = True) -> int: - model1 = np.array([[16, 15, 14, 13], [9, 10, 11, 12], [8, 7, 6, 5], [1, 2, 2, 4]]) model2 = np.array([[16, 15, 12, 4], [14, 13, 11, 3], [10, 9, 8, 2], [7, 6, 5, 1]]) model3 = np.array([[16, 15, 14, 4], [13, 12, 11, 3], [10, 9, 8, 2], [7, 6, 5, 1]]) @@ -154,8 +153,6 @@ def generate(grid: np.array) -> np.array: return grid - - config = EasyDict(dict( env_name="game_2048_env_2048", save_replay_gif=False, @@ -163,10 +160,13 @@ def 
generate(grid: np.array) -> np.array: replay_path=None, act_scale=True, channel_last=True, - obs_type='array', - reward_normalize=True, - reward_scale=100, - max_tile=int(2**16), + obs_type='array', + raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' + # reward_type='merged_tiles_plus_log_max_tile_num', + # reward_normalize=True, + reward_normalize=False, + reward_norm_scale=100, + max_tile=int(2 ** 16), delay_reward_step=0, prob_random_agent=0., max_episode_steps=int(1e4), @@ -183,13 +183,13 @@ def generate(grid: np.array) -> np.array: step = 0 while True: # action = env.human_to_action() - print('='*20) + print('=' * 20) grid = obs.astype(np.int64) action = game_2048_env.random_action() action = rule_based_search(grid) - if(action == 1): - action=2 - elif(action == 2): + if action == 1: + action = 2 + elif action == 2: action = 1 obs, reward, done, info = game_2048_env.step(action) step += 1 diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 79c04d577..638c6aad9 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -1,33 +1,45 @@ from easydict import EasyDict -env_name = 'game_2048' -action_space_size = 4 -chance_space_size = 4 + # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -collector_env_num = 1 -n_episode = 1 -evaluator_env_num = 1 -num_simulations = 5 # TODO(pu):100 -update_per_collect = 3 -batch_size = 5 -max_env_step = int(1e6) +env_name = 'game_2048' +action_space_size = 4 +chance_space_size = 4 + +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 512 +max_env_step = int(1e8) reanalyze_ratio = 0. + +# collector_env_num = 1 +# n_episode = 1 +# evaluator_env_num = 1 +# num_simulations = 5 # TODO(pu):100 +# update_per_collect = 3 +# batch_size = 5 +# max_env_step = int(1e6) +# reanalyze_ratio = 0. 
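# NOTE (observation on the configs in this patch series): this config relies on the learned chance
# encoder with chance_space_size = 4, so the codebook does not have to match the environment's true
# chance outcomes; a config with explicit_chance_label=True (e.g. explicit_2048_config.py with
# chance_space_size = 32) instead has to cover the env's chance indices (game_2048_env uses
# 4 * row + col, plus an offset of 16 for a spawned 4-tile when its chance_space_size is 32).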
# ============================================================== # end of the most frequently changed config specified by the user # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-true_seed0', + exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_seed0', env=dict( stop_value=int(1e6), env_name=env_name, obs_shape=(16, 4, 4), obs_type='dict_observation', - reward_normalize=True, + raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' + reward_normalize=False, reward_scale=100, - max_tile=int(2**16), # 2**11=2048, 2**16=65536 + max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -54,9 +66,14 @@ td_steps=10, discount_factor=0.999, manual_temperature_decay=True, - optim_type='SGD', - lr_piecewise_constant_decay=True, - learning_rate=0.2, # init lr for manually decay schedule + # optim_type='SGD', + # lr_piecewise_constant_decay=True, + # learning_rate=0.2, # init lr for manually decay schedule + + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=3e-3, + num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=2, # default is 0 diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 4379a6c31..03b056ed0 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import copy import itertools import logging @@ -28,15 +26,16 @@ class Game2048Env(gym.Env): act_scale=True, channel_last=True, obs_type='raw_observation', # options=['raw_observation', 'dict_observation', 'array'] - reward_normalize=True, - reward_scale=100, - max_tile=int(2**16), # 2**11=2048, 2**16=65536 + reward_normalize=False, + reward_norm_scale=100, + reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' + max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 delay_reward_step=0, prob_random_agent=0., max_episode_steps=int(1e6), is_collect=True, - ignore_legal_actions = True, - need_flatten = False, + ignore_legal_actions=True, + need_flatten=False, ) metadata = {'render.modes': ['human', 'ansi', 'rgb_array']} @@ -56,22 +55,23 @@ def __init__(self, cfg: dict) -> None: self._save_replay_count = 0 self.channel_last = cfg.channel_last self.obs_type = cfg.obs_type + self.reward_type = cfg.reward_type self.reward_normalize = cfg.reward_normalize - self.reward_scale = cfg.reward_scale + self.reward_norm_scale = cfg.reward_norm_scale + assert self.reward_type in ['raw', 'merged_tiles_plus_log_max_tile_num'] + assert self.reward_type=='raw' or (self.reward_type=='merged_tiles_plus_log_max_tile_num' and self.reward_normalize==False) self.max_tile = cfg.max_tile self.max_episode_steps = cfg.max_episode_steps self.is_collect = cfg.is_collect self.ignore_legal_actions = cfg.ignore_legal_actions self.need_flatten = cfg.need_flatten self.chance = 0 - + self.chance_space_size = 16 # 32 for 2 and 4, 16 for 2 + self.max_tile_num = 0 self.size = 4 self.w = self.size self.h = self.size self.squares = self.size * self.size - - self.max_value = 2 - self.episode_return = 0 # Members for gym implementation: self._action_space = spaces.Discrete(4) @@ -80,39 
+80,14 @@ def __init__(self, cfg: dict) -> None: self.set_illegal_move_reward(0.) self.set_max_tile(max_tile=self.max_tile) - if self.reward_normalize: - self._reward_range = (0., self.max_tile) - else: - self._reward_range = (0., self.max_tile) + self._reward_range = (0., self.max_tile) - # TODO(pu): why + # for render self.grid_size = 70 # Initialise the random seed of the gym environment. self.seed() - def seed(self, seed=None, seed1=None): - """Set the random seed for the gym environment.""" - self.np_random, seed = seeding.np_random(seed) - return [seed] - - def set_illegal_move_reward(self, reward): - """Define the reward/penalty for performing an illegal move. Also need - to update the reward range for this.""" - # Guess that the maximum reward is also 2**squares though you'll probably never get that. - # (assume that illegal move reward is the lowest value that can be returned - # TODO: check that this is correct - self.illegal_move_reward = reward - self.reward_range = (self.illegal_move_reward, float(2 ** self.squares)) - - def set_max_tile(self, max_tile: int = 2048): - """ - Define the maximum tile that will end the game (e.g. 2048). None means no limit. - This does not affect the state returned. - """ - assert max_tile is None or isinstance(max_tile, int) - self.max_tile = max_tile - def reset(self): """Reset the game board-matrix and add 2 tiles.""" self.episode_length = 0 @@ -120,7 +95,6 @@ def reset(self): self.episode_return = 0 self._final_eval_reward = 0.0 self.should_done = False - self.max_value = 2 logging.debug("Adding tiles") # TODO(pu): why add_tiles twice? @@ -130,7 +104,7 @@ def reset(self): action_mask = np.zeros(4, 'int8') action_mask[self.legal_actions] = 1 - observation = encoding_board(self.board) + observation = encode_board(self.board) observation = observation.astype(np.float32) assert observation.shape == (4, 4, 16) @@ -145,7 +119,7 @@ def reset(self): if self.obs_type == 'dict_observation': observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} elif self.obs_type == 'array': - observation = self.board + observation = self.board else: observation = observation return observation @@ -153,34 +127,39 @@ def reset(self): def step(self, action): """Perform one step of the game. This involves moving and adding a new tile.""" self.episode_length += 1 - info = {'illegal_move': False} if action not in self.legal_actions: - raise IllegalActionError(f"You input illegal action: {action}, the legal_actions are {self.legal_actions}. ") - - empty_num1 = len(self.get_empty_location()) - reward_eval = float(self.move(action)) - empty_num2 = len(self.get_empty_location()) - reward_collect = float(empty_num2 - empty_num1) - #reward_collect = float(empty_num1 - empty_num2) - max_num = np.max(self.board) - if max_num > self.max_value: - reward_collect += np.log2(max_num) * 0.1 - self.max_value = max_num - self.episode_return += reward_eval - assert reward_eval <= 2 ** (self.w * self.h) + raise IllegalActionError( + f"You input illegal action: {action}, the legal_actions are {self.legal_actions}. 
") + + if self.reward_type == 'merged_tiles_plus_log_max_tile_num': + empty_num1 = len(self.get_empty_location()) + raw_reward = float(self.move(action)) + if self.reward_type == 'merged_tiles_plus_log_max_tile_num': + empty_num2 = len(self.get_empty_location()) + num_of_merged_tiles = float(empty_num2 - empty_num1) + reward_merged_tiles_plus_log_max_tile_num = num_of_merged_tiles + max_tile_num = self.highest() + if max_tile_num > self.max_tile_num: + reward_merged_tiles_plus_log_max_tile_num += np.log2(max_tile_num) * 0.1 + self.max_tile_num = max_tile_num + + self.episode_return += raw_reward + assert raw_reward <= 2 ** (self.w * self.h) self.add_random_2_4_tile() done = self.is_end() - reward_collect = float(reward_collect) - reward_eval = float(reward_eval) + if self.reward_type == 'merged_tiles_plus_log_max_tile_num': + reward_merged_tiles_plus_log_max_tile_num = float(reward_merged_tiles_plus_log_max_tile_num) + elif self.reward_type == 'raw': + raw_reward = float(raw_reward) if self.episode_length >= self.max_episode_steps: # print("episode_length: {}".format(self.episode_length)) done = True - observation = encoding_board(self.board) + observation = encode_board(self.board) observation = observation.astype(np.float32) - + assert observation.shape == (4, 4, 16) if not self.channel_last: @@ -188,7 +167,7 @@ def step(self, action): # (W, H, C) -> (C, W, H) # e.g. (4, 4, 16) -> (16, 4, 4) observation = np.transpose(observation, [2, 0, 1]) - + if self.need_flatten: observation = observation.reshape(-1) action_mask = np.zeros(4, 'int8') @@ -197,28 +176,103 @@ def step(self, action): if self.obs_type == 'dict_observation': observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} elif self.obs_type == 'array': - observation = self.board + observation = self.board else: observation = observation if self.reward_normalize: - reward_normalize = reward_collect - self._final_eval_reward += reward_normalize - reward = reward_collect + reward_normalize = raw_reward / self.reward_norm_scale + reward = reward_normalize else: - self._final_eval_reward += reward_eval - reward = reward_eval - reward = to_ndarray([reward]).astype(np.float32) + reward = raw_reward + + self._final_eval_reward += raw_reward - info = {"raw_reward": reward_eval, "max_tile": self.highest(), 'highest': self.highest()} + if self.reward_type == 'merged_tiles_plus_log_max_tile_num': + reward = to_ndarray([reward_merged_tiles_plus_log_max_tile_num]).astype(np.float32) + elif self.reward_type == 'raw': + reward = to_ndarray([reward]).astype(np.float32) + info = {"raw_reward": raw_reward, "max_tile": self.highest(), 'highest': self.highest()} if done: info['eval_episode_return'] = self._final_eval_reward - if self.reward_normalize: - return BaseEnvTimestep(observation, reward, done, info) + return BaseEnvTimestep(observation, reward, done, info) + + def move(self, direction, trial=False): + """ + Overview: + Perform one move of the game. Shift things to one side then, + combine. directions 0, 1, 2, 3 are up, right, down, left. + Returns the reward that [would have] got. + Arguments: + - direction (:obj:`int`): The direction to move. + - trial (:obj:`bool`): Whether this is a trial move. 
+ """ + if not trial: + if direction == 0: + logging.debug("Up") + elif direction == 1: + logging.debug("Right") + elif direction == 2: + logging.debug("Down") + elif direction == 3: + logging.debug("Left") + + changed = False + move_reward = 0 + dir_div_two = int(direction / 2) + dir_mod_two = int(direction % 2) + # 0 for towards up or left, 1 for towards bottom or right + shift_direction = dir_mod_two ^ dir_div_two + + # Construct a range for extracting row/column into a list + rx = list(range(self.w)) + ry = list(range(self.h)) + + if dir_mod_two == 0: + # Up or down, split into columns + for y in range(self.h): + old = [self.get(x, y) for x in rx] + (new, ms) = self.shift(old, shift_direction) + move_reward += ms + if old != new: + changed = True + if not trial: + for x in rx: + self.set(x, y, new[x]) else: - return BaseEnvTimestep(observation, reward, done, info) + # Left or right, split into rows + for x in range(self.w): + old = [self.get(x, y) for y in ry] + (new, ms) = self.shift(old, shift_direction) + move_reward += ms + if old != new: + changed = True + if not trial: + for y in ry: + self.set(x, y, new[y]) + # if not changed: + # raise IllegalMove + + return move_reward + + def set_illegal_move_reward(self, reward): + """Define the reward/penalty for performing an illegal move. Also need + to update the reward range for this.""" + # Guess that the maximum reward is also 2**squares though you'll probably never get that. + # (assume that illegal move reward is the lowest value that can be returned + # TODO: check that this is correct + self.illegal_move_reward = reward + self.reward_range = (self.illegal_move_reward, float(2 ** self.squares)) + + def set_max_tile(self, max_tile: int = 2048): + """ + Define the maximum tile that will end the game (e.g. 2048). None means no limit. + This does not affect the state returned. + """ + assert max_tile is None or isinstance(max_tile, int) + self.max_tile = max_tile def render(self, mode='human'): if mode == 'rgb_array': @@ -275,20 +329,25 @@ def add_random_2_4_tile(self): """Add a tile with value 2 or 4 with different probabilities.""" possible_tiles = np.array([2, 4]) tile_probabilities = np.array([0.9, 0.1]) - val = self.np_random.choice(possible_tiles, 1, p=tile_probabilities)[0] + tile_val = self.np_random.choice(possible_tiles, 1, p=tile_probabilities)[0] empty_location = self.get_empty_location() # assert empty_location.shape[0] if empty_location.shape[0] == 0: - self.should_done = True - return + self.should_done = True + return empty_idx = self.np_random.choice(empty_location.shape[0]) empty = empty_location[empty_idx] - logging.debug("Adding %s at %s", val, (empty[0], empty[1])) - val_chance_cum = 0 - # if val == 4: - # val_chance_cum = 16 - self.chance = val_chance_cum + 4 * empty[0] + empty[1] - self.set(empty[0], empty[1], val) + logging.debug("Adding %s at %s", tile_val, (empty[0], empty[1])) + + if self.chance_space_size == 16: + self.chance = 4 * empty[0] + empty[1] + elif self.chance_space_size == 32: + if tile_val == 2: + self.chance = 4 * empty[0] + empty[1] + elif tile_val == 4: + self.chance = 16 + 4 * empty[0] + empty[1] + + self.set(empty[0], empty[1], tile_val) def get(self, x, y): """Get the value of one square.""" @@ -306,63 +365,6 @@ def highest(self): """Report the highest tile on the board.""" return np.max(self.board) - def move(self, direction, trial=False): - """ - Overview: - Perform one move of the game. Shift things to one side then, - combine. directions 0, 1, 2, 3 are up, right, down, left. 
- Returns the reward that [would have] got. - Arguments: - - direction (:obj:`int`): The direction to move. - - trial (:obj:`bool`): Whether this is a trial move. - """ - if not trial: - if direction == 0: - logging.debug("Up") - elif direction == 1: - logging.debug("Right") - elif direction == 2: - logging.debug("Down") - elif direction == 3: - logging.debug("Left") - - changed = False - move_reward = 0 - dir_div_two = int(direction / 2) - dir_mod_two = int(direction % 2) - # 0 for towards up or left, 1 for towards bottom or right - shift_direction = dir_mod_two ^ dir_div_two - - # Construct a range for extracting row/column into a list - rx = list(range(self.w)) - ry = list(range(self.h)) - - if dir_mod_two == 0: - # Up or down, split into columns - for y in range(self.h): - old = [self.get(x, y) for x in rx] - (new, ms) = self.shift(old, shift_direction) - move_reward += ms - if old != new: - changed = True - if not trial: - for x in rx: - self.set(x, y, new[x]) - else: - # Left or right, split into rows - for x in range(self.w): - old = [self.get(x, y) for y in ry] - (new, ms) = self.shift(old, shift_direction) - move_reward += ms - if old != new: - changed = True - if not trial: - for y in ry: - self.set(x, y, new[y]) - # if not changed: - # raise IllegalMove - - return move_reward @property def legal_actions(self): @@ -375,7 +377,7 @@ def legal_actions(self): - legal_actions (:obj:`list`): The legal actions. """ if self.ignore_legal_actions: - return [0,1,2,3] + return [0, 1, 2, 3] legal_actions = [] for direction in range(4): changed = False @@ -479,6 +481,11 @@ def set_board(self, new_board): """Set the whole board-matrix, useful for testing.""" self.board = new_board + def seed(self, seed=None, seed1=None): + """Set the random seed for the gym environment.""" + self.np_random, seed = seeding.np_random(seed) + return [seed] + def random_action(self) -> np.ndarray: random_action = self.action_space.sample() if isinstance(random_action, np.ndarray): @@ -503,7 +510,6 @@ def reward_space(self) -> gym.spaces.Space: def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') cfg = copy.deepcopy(cfg) - # cfg.reward_normalize = True # when collect data, sometimes we need to normalize the reward # reward_normalize is determined by the config. cfg.is_collect = True @@ -532,28 +538,35 @@ def pairwise(iterable): class IllegalMove(Exception): pass + class IllegalActionError(Exception): pass -def encoding_board(flat, num_of_template_tiles=16): + +def encode_board(flat_board, num_of_template_tiles=16): """ Overview: - Convert an [4, 4] raw board into [4, 4, num_of_template_tiles] one-hot encoding. + This function converts a [4, 4] raw game board into a [4, 4, num_of_template_tiles] one-hot encoded board. Arguments: - - board (:obj:`np.ndarray`): the raw board - - num_of_template_tiles (:obj:`int`): the number of template_tiles + - flat_board (:obj:`np.ndarray`): The raw game board, expected to be a 2D numpy array. + - num_of_template_tiles (:obj:`int`): The number of unique tiles to consider in the encoding, + default value is 16. Returns: - - one_hot_board (:obj:`np.ndarray`): the one-hot encoding board + - one_hot_board (:obj:`np.ndarray`): The one-hot encoded game board. 
""" - # TODO(pu): the more elegant one-hot encoding implementation - # template_tiles is what each layer represents - # template_tiles = 2 ** (np.arange(num_of_template_tiles, dtype=int) + 1) - template_tiles = 2 ** (np.arange(num_of_template_tiles, dtype=int)) - template_tiles[0] = 0 - # layered is the flat board repeated num_of_template_tiles times - layered = np.repeat(flat[:, :, np.newaxis], num_of_template_tiles, axis=-1) - - # Now set the values in the board to 1 or zero depending on whether they match template_tiles. - # template_tiles is broadcast across a number of axes - one_hot_board = np.where(layered == template_tiles, 1, 0) - return one_hot_board \ No newline at end of file + # Generate a sequence of powers of 2, corresponding to the unique tile values. + # In the game, tile values are powers of 2. So, each unique tile is represented by 2 raised to some power. + # The first tile is considered as 0 (empty tile). + tile_values = 2 ** np.arange(num_of_template_tiles, dtype=int) + tile_values[0] = 0 # The first tile represents an empty slot, so set its value to 0. + + # Create a 3D array from the 2D input board by repeating it along a new axis. + # This creates a 'layered' view of the board, where each layer corresponds to one unique tile value. + layered_board = np.repeat(flat_board[:, :, np.newaxis], num_of_template_tiles, axis=-1) + + # Perform the one-hot encoding: + # For each layer of the 'layered_board', mark the positions where the tile value in the 'flat_board' + # matches the corresponding value in 'tile_values'. If a match is found, mark it as 1 (True), else 0 (False). + one_hot_board = (layered_board == tile_values).astype(int) + + return one_hot_board From 6bd0310adab5bb602d8affd8f16d5efb1fd2152a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Wed, 2 Aug 2023 21:06:44 +0800 Subject: [PATCH 11/28] polish(pu): polish chance encoder --- lzero/model/stochastic_muzero_model.py | 102 ++++++++++++++---- lzero/policy/stochastic_muzero.py | 6 +- zoo/game_2048/config/muzero_2048_config.py | 34 +++--- .../config/stochastic_muzero_2048_config.py | 17 +-- 4 files changed, 109 insertions(+), 50 deletions(-) diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 0083a6ae6..da23a0fbb 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -126,7 +126,7 @@ def __init__( downsample, ) - self.encoder = Encoder_function( + self.encoder = ChanceEncoder( observation_shape, chance_space_size ) self.dynamics_network = DynamicsNetwork( @@ -292,7 +292,7 @@ def _representation(self, observation: torch.Tensor) -> torch.Tensor: latent_state = renormalize(latent_state) return latent_state - def _encode_vqvae(self, observation: torch.Tensor): + def chance_encode(self, observation: torch.Tensor): output = self.encoder(observation) return output @@ -787,9 +787,9 @@ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso return policy, value -class ImgNet(nn.Module): +class ChanceEncoderBackbone(nn.Module): def __init__(self, observation_space_dimensions, table_vec_dim=4): - super(ImgNet, self).__init__() + super(ChanceEncoderBackbone, self).__init__() self.conv1 = nn.Conv2d(observation_space_dimensions[0] * 2, 32, 3, padding=1) self.conv2 = nn.Conv2d(32, 64, 3, padding=1) self.fc1 = nn.Linear(64 * observation_space_dimensions[1] * observation_space_dimensions[2], 128) @@ -799,7 +799,6 @@ def __init__(self, observation_space_dimensions, 
table_vec_dim=4): def forward(self, x): x = torch.relu(self.conv1(x)) x = torch.relu(self.conv2(x)) - # x = x.view(-1, 64 * 4 * 4) x = x.view(x.shape[0], -1) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) @@ -807,41 +806,100 @@ def forward(self, x): return x -class Encoder_function(nn.Module): - def __init__(self, - observation_space_dimensions, - action_dimension): +class ChanceEncoder(nn.Module): + def __init__(self, observation_space_dimensions, action_dimension): super().__init__() + # Specify the action space for the model self.action_space = action_dimension - self.encoder = ImgNet(observation_space_dimensions, action_dimension) + # Define the encoder, which transforms observations into a latent space + self.encoder = ChanceEncoderBackbone(observation_space_dimensions, action_dimension) + # Using the Straight Through Estimator method for backpropagation self.onehot_argmax = StraightThroughEstimator() def forward(self, o_i): - # https://openreview.net/pdf?id=X6D9bAHhBQ1 [page:5 chance outcome] - # c_e_t = torch.nn.Softmax(-1)(self.encoder(o_i)) - c_e_t = self.encoder(o_i) - # c_t= torch.zeros_like(c_e_t).scatter_(-1, torch.argmax(c_e_t, dim=-1,keepdim=True), 1.) - c_t = self.onehot_argmax(c_e_t) - return c_t, c_e_t + """ + Forward method for the ChanceEncoder. This method takes an observation + and applies the encoder to transform it to a latent space. Then applies the + StraightThroughEstimator to this encoding. + + References: + Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5, + Chance Outcomes section. + + Args: + o_i (Tensor): Observation tensor. + + Returns: + chance_t (Tensor): Transformed tensor after applying one-hot argmax. + chance_encoding_t (Tensor): Encoding of the input observation tensor. + """ + # Apply the encoder to the observation + chance_encoding_t = self.encoder(o_i) + # Apply one-hot argmax to the encoding + chance_t = self.onehot_argmax(chance_encoding_t) + return chance_t, chance_encoding_t class StraightThroughEstimator(nn.Module): def __init__(self): - super(StraightThroughEstimator, self).__init__() + super().__init__() def forward(self, x): - x = Onehot_argmax.apply(x) + """ + Forward method for the StraightThroughEstimator. This applies the one-hot argmax + function to the input tensor. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Transformed tensor after applying one-hot argmax. + """ + # Apply one-hot argmax to the input + x = OnehotArgmax.apply(x) return x -# straight-through estimator is used during the backward to allow the gradients to flow only to the encoder during the backpropagation. -class Onehot_argmax(torch.autograd.Function): - # more information at : https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html +class OnehotArgmax(torch.autograd.Function): + """ + Custom PyTorch function for one-hot argmax. This function transforms the input tensor + into a one-hot tensor where the index with the maximum value in the original tensor is + set to 1 and all other indices are set to 0. It allows gradients to flow to the encoder + during backpropagation. + + For more information, refer to: + https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html + """ + @staticmethod def forward(ctx, input): - # since the codebook is constant ,we can just use a transformation. no need to create a codebook and matmul c_e_t and codebook for argmax + """ + Forward method for the one-hot argmax function. 
This method transforms the input + tensor into a one-hot tensor. + + Args: + ctx (context): A context object that can be used to stash information for + backward computation. + input (Tensor): Input tensor. + + Returns: + Tensor: One-hot tensor. + """ + # Transform the input tensor to a one-hot tensor return torch.zeros_like(input).scatter_(-1, torch.argmax(input, dim=-1, keepdim=True), 1.) @staticmethod def backward(ctx, grad_output): + """ + Backward method for the one-hot argmax function. This method allows gradients + to flow to the encoder during backpropagation. + + Args: + ctx (context): A context object that was stashed in the forward pass. + grad_output (Tensor): The gradient of the output tensor. + + Returns: + Tensor: The gradient of the input tensor. + """ return grad_output + diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 13dc8fdd9..8767e93f6 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -371,9 +371,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # concat consecutive frames to calculate ground truth chance former_frame = encoder_image_list[step_i] latter_frame = encoder_image_list[step_i + 1] - concat_frame = torch.cat((former_frame, latter_frame), dim=1) + concat_frames = torch.cat((former_frame, latter_frame), dim=1) - chance_code, encode_output = self._learn_model._encode_vqvae(concat_frame) + chance_code, chance_encoding = self._learn_model.chance_encode(concat_frames) chance_code_long = torch.argmax(chance_code, dim=1).long().unsqueeze(-1) # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, @@ -420,7 +420,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_code) - commitment_loss += cross_entropy_loss(encode_output, chance_code) + commitment_loss += cross_entropy_loss(chance_encoding, chance_code) afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_i]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 50afd7976..3f087bdc1 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -7,23 +7,23 @@ env_name = 'game_2048' action_space_size = 4 -# collector_env_num = 8 -# n_episode = 8 -# evaluator_env_num = 3 -# num_simulations = 50 -# update_per_collect = 200 -# batch_size = 512 -# max_env_step = int(1e8) -# reanalyze_ratio = 0. - -collector_env_num = 1 -n_episode = 1 -evaluator_env_num = 1 -num_simulations = 5 -update_per_collect = 3 -batch_size = 5 -max_env_step = int(1e6) +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 512 +max_env_step = int(5e6) reanalyze_ratio = 0. + +# collector_env_num = 1 +# n_episode = 1 +# evaluator_env_num = 1 +# num_simulations = 5 +# update_per_collect = 3 +# batch_size = 5 +# max_env_step = int(1e6) +# reanalyze_ratio = 0. 
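As a quick illustration of the one-hot board encoding that the `encode_board` helper in the 2048 env diff above produces, the following standalone sketch reproduces the same logic on a tiny 2x2 board with only 4 template tiles; the board values and tile count here are illustrative only:

import numpy as np

board = np.array([[0, 2],
                  [4, 8]])
num_of_template_tiles = 4
tile_values = 2 ** np.arange(num_of_template_tiles, dtype=int)   # [1, 2, 4, 8]
tile_values[0] = 0                                               # channel 0 marks empty cells
layered = np.repeat(board[:, :, np.newaxis], num_of_template_tiles, axis=-1)
one_hot = (layered == tile_values).astype(int)
print(one_hot[1, 1])   # [0 0 0 1] -> the 8-tile lights up channel 3 (2 ** 3 == 8)
print(one_hot[0, 0])   # [1 0 0 0] -> the empty cell lights up channel 0

Likewise, the `StraightThroughEstimator`/`OnehotArgmax` pair added to lzero/model/stochastic_muzero_model.py above can be exercised in isolation to see the straight-through behaviour: the forward pass emits a hard one-hot, while the backward pass copies the incoming gradient through unchanged. This is a minimal self-contained check; the `*Sketch` name is illustrative and not part of the patch:

import torch

class OnehotArgmaxSketch(torch.autograd.Function):
    # Forward: one-hot of the argmax. Backward: pass the gradient through unchanged.

    @staticmethod
    def forward(ctx, logits):
        return torch.zeros_like(logits).scatter_(-1, logits.argmax(dim=-1, keepdim=True), 1.)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

logits = torch.randn(2, 4, requires_grad=True)
onehot = OnehotArgmaxSketch.apply(logits)
onehot.sum().backward()
print(onehot)        # exactly one 1 per row
print(logits.grad)   # all ones: the identity (straight-through) gradient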
# ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -62,7 +62,7 @@ td_steps=10, discount_factor=0.999, manual_temperature_decay=True, - + threshold_training_steps_for_final_temperature=int(1e5), # optim_type='SGD', # lr_piecewise_constant_decay=True, # learning_rate=0.2, # init lr for manually decay schedule diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 638c6aad9..a24026f82 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -17,14 +17,14 @@ max_env_step = int(1e8) reanalyze_ratio = 0. -# collector_env_num = 1 -# n_episode = 1 -# evaluator_env_num = 1 -# num_simulations = 5 # TODO(pu):100 -# update_per_collect = 3 -# batch_size = 5 -# max_env_step = int(1e6) -# reanalyze_ratio = 0. +collector_env_num = 1 +n_episode = 1 +evaluator_env_num = 1 +num_simulations = 5 +update_per_collect = 3 +batch_size = 5 +max_env_step = int(1e6) +reanalyze_ratio = 0. # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -66,6 +66,7 @@ td_steps=10, discount_factor=0.999, manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(1e5), # optim_type='SGD', # lr_piecewise_constant_decay=True, # learning_rate=0.2, # init lr for manually decay schedule From f85aec33e6394f83872e703037f79433cc7ae9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Thu, 3 Aug 2023 22:12:18 +0800 Subject: [PATCH 12/28] fix(pu): fix chance encoder related loss --- .../buffer/game_buffer_stochastic_muzero.py | 26 ++++++---- lzero/mcts/buffer/game_segment.py | 22 +++++--- lzero/model/stochastic_muzero_model.py | 8 +-- lzero/policy/stochastic_muzero.py | 51 ++++++++++++------- lzero/worker/muzero_collector.py | 37 +++++++++----- ...ochastic_muzero_2048_true_chance_config.py | 15 +++--- 6 files changed, 102 insertions(+), 57 deletions(-) diff --git a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py index 3f93e695b..d8b7dbf4c 100644 --- a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py +++ b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py @@ -111,7 +111,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data batch_size = len(batch_index_list) obs_list, action_list, mask_list = [], [], [] - chance_list = [] + if self._cfg.use_ture_chance_label_in_chance_encoder: + chance_list = [] # prepare the inputs of a batch for i in range(batch_size): game = game_segment_list[i] @@ -119,8 +120,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() - chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() + if self._cfg.use_ture_chance_label_in_chance_encoder: + chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() # add mask for invalid actions (out of trajectory) mask_tmp = [1. 
for i in range(len(actions_tmp))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))] @@ -130,10 +132,11 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: np.random.randint(0, game.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] - chances_tmp += [ - np.random.randint(0, game.action_space_size) - for _ in range(self._cfg.num_unroll_steps - len(chances_tmp)) - ] + if self._cfg.use_ture_chance_label_in_chance_encoder: + chances_tmp += [ + np.random.randint(0, game.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(chances_tmp)) + ] # obtain the input observations # pad if length of obs in game_segment is less than stack+num_unroll_steps # e.g. stack+num_unroll_steps 4+5 @@ -144,13 +147,18 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: ) action_list.append(actions_tmp) mask_list.append(mask_tmp) - chance_list.append(chances_tmp) + if self._cfg.use_ture_chance_label_in_chance_encoder: + chance_list.append(chances_tmp) # formalize the input observations obs_list = prepare_observation(obs_list, self._cfg.model.model_type) # formalize the inputs of a batch - current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list, chance_list] + if self._cfg.use_ture_chance_label_in_chance_encoder: + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list, + chance_list] + else: + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] for i in range(len(current_batch)): current_batch[i] = np.asarray(current_batch[i]) diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index b0d9cfdb0..cb2eaeaae 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -64,7 +64,6 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.action_mask_segment = [] self.to_play_segment = [] - self.chance_segment = [] self.target_values = [] self.target_rewards = [] @@ -74,6 +73,9 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea if self.config.sampled_algo: self.root_sampled_actions = [] + if self.config.use_ture_chance_label_in_chance_encoder: + self.chance_segment = [] + def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray: """ @@ -130,7 +132,7 @@ def append( reward: np.ndarray, action_mask: np.ndarray = None, to_play: int = -1, - chance: np.ndarray=0, + chance: np.ndarray = 0, ) -> None: """ Overview: @@ -142,11 +144,12 @@ def append( self.action_mask_segment.append(action_mask) self.to_play_segment.append(to_play) - self.chance_segment.append(chance) + if self.config.use_ture_chance_label_in_chance_encoder: + self.chance_segment.append(chance) def pad_over( self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List, - next_segment_child_visits: List, next_chances: List = None,next_segment_improved_policy: List = None, + next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None, ) -> None: """ Overview: @@ -187,8 +190,9 @@ def pad_over( if self.config.gumbel_algo: for improved_policy in next_segment_improved_policy: self.improved_policy_probs.append(improved_policy) - for chances in next_chances: - self.chance_segment.append(chances) + if self.config.use_ture_chance_label_in_chance_encoder: + for chances in 
next_chances: + self.chance_segment.append(chances) def get_targets(self, timestep: int) -> Tuple: """ @@ -258,7 +262,8 @@ def game_segment_to_array(self) -> None: self.action_mask_segment = np.array(self.action_mask_segment) self.to_play_segment = np.array(self.to_play_segment) - self.chance_segment = np.array(self.chance_segment) + if self.config.use_ture_chance_label_in_chance_encoder: + self.chance_segment = np.array(self.chance_segment) def reset(self, init_observations: np.ndarray) -> None: """ @@ -277,7 +282,8 @@ def reset(self, init_observations: np.ndarray) -> None: self.action_mask_segment = [] self.to_play_segment = [] - self.chance_segment = [] + if self.config.use_ture_chance_label_in_chance_encoder: + self.chance_segment = [] assert len(init_observations) == self.frame_stack_num diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index da23a0fbb..a40f027e1 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -788,13 +788,13 @@ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso class ChanceEncoderBackbone(nn.Module): - def __init__(self, observation_space_dimensions, table_vec_dim=4): + def __init__(self, observation_space_dimensions, chance_encoding_dim=4): super(ChanceEncoderBackbone, self).__init__() self.conv1 = nn.Conv2d(observation_space_dimensions[0] * 2, 32, 3, padding=1) self.conv2 = nn.Conv2d(32, 64, 3, padding=1) self.fc1 = nn.Linear(64 * observation_space_dimensions[1] * observation_space_dimensions[2], 128) self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, table_vec_dim) + self.fc3 = nn.Linear(64, chance_encoding_dim) def forward(self, x): x = torch.relu(self.conv1(x)) @@ -836,8 +836,8 @@ def forward(self, o_i): # Apply the encoder to the observation chance_encoding_t = self.encoder(o_i) # Apply one-hot argmax to the encoding - chance_t = self.onehot_argmax(chance_encoding_t) - return chance_t, chance_encoding_t + chance_onehot_t = self.onehot_argmax(chance_encoding_t) + return chance_encoding_t, chance_onehot_t class StraightThroughEstimator(nn.Module): diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 20acc6f81..16cd70ef2 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -147,6 +147,8 @@ class StochasticMuZeroPolicy(Policy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS. + use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. 
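A note on the `1 + pos_in_game_segment` offset in the buffer slicing above (PATCH 12/28): the chance outcome recorded at step t+1 is the one that resolves the action taken at step t, so chance targets are shifted by one position relative to the sampled actions. A purely illustrative index walk-through, assuming that convention:

num_unroll_steps = 5
pos_in_game_segment = 10

action_segment = [f"a_{t}" for t in range(30)]
chance_segment = [f"c_{t}" for t in range(30)]   # c_{t+1} resolves a_t

actions_tmp = action_segment[pos_in_game_segment:pos_in_game_segment + num_unroll_steps]
chances_tmp = chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment + num_unroll_steps]

print(actions_tmp)  # ['a_10', 'a_11', 'a_12', 'a_13', 'a_14']
print(chances_tmp)  # ['c_11', 'c_12', 'c_13', 'c_14', 'c_15']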
@@ -273,20 +275,23 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in self._target_model.train() current_batch, target_batch = data - obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch + if self._cfg.use_ture_chance_label_in_chance_encoder: + obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch + else: + obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch target_reward, target_value, target_policy = target_batch - if self._cfg.explicit_chance_label: - chance_batch = torch.LongTensor(chance_batch).to(self._cfg.device) - chance_batch =torch.nn.functional.one_hot(chance_batch, self._cfg.model.chance_space_size) + if self._cfg.use_ture_chance_label_in_chance_encoder: + chance_batch = torch.Tensor(chance_batch).to(self._cfg.device) + chance_one_hot_batch = torch.nn.functional.one_hot(chance_batch.long(), self._cfg.model.chance_space_size) obs_batch, obs_target_batch = prepare_obs(obs_batch_orig, self._cfg) - encoder_image_list = [] - encoder_image_list.append(obs_batch) + obs_list_for_chance_encoder = [] + obs_list_for_chance_encoder.append(obs_batch) for i in range(self._cfg.num_unroll_steps): beg_index = self._cfg.model.image_channel * i end_index = self._cfg.model.image_channel * (i + self._cfg.model.frame_stack_num) - encoder_image_list.append(obs_target_batch[:, beg_index:end_index, :, :]) + obs_list_for_chance_encoder.append(obs_target_batch[:, beg_index:end_index, :, :]) # do augmentations if self._cfg.use_augmentation: @@ -372,19 +377,23 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in ) after_state, afterstate_reward, afterstate_value, afterstate_policy_logits = mz_network_output_unpack(network_output) - # concat consecutive frames to calculate ground truth chance - former_frame = encoder_image_list[step_i] - latter_frame = encoder_image_list[step_i + 1] + # concat consecutive frames to predict chance + former_frame = obs_list_for_chance_encoder[step_i] + latter_frame = obs_list_for_chance_encoder[step_i + 1] concat_frame = torch.cat((former_frame, latter_frame), dim=1) - chance_code, encode_output = self._learn_model._encode_vqvae(concat_frame) - if self._cfg.explicit_chance_label: - chance_code = chance_batch[:, step_i] - chance_code_long = torch.argmax(chance_code, dim=1).long().unsqueeze(-1) + chance_encoding, chance_one_hot = self._learn_model.chance_encode(concat_frame) + if self._cfg.use_ture_chance_label_in_chance_encoder: + true_chance_code = chance_batch[:, step_i] + chance_code = true_chance_code + true_chance_one_hot = chance_one_hot_batch[:, step_i] + else: + chance_code = torch.argmax(chance_encoding, dim=1).long().unsqueeze(-1) # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, # given current ``after_state`` and ``chance_long``. # And then predict policy_logits and value with the prediction function. - network_output = self._learn_model.recurrent_inference(after_state, chance_code_long, afterstate=True) + network_output = self._learn_model.recurrent_inference(after_state, chance_code, afterstate=True) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) # transform the scaled value or its categorical representation to its original value, @@ -424,9 +433,15 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # NOTE: the +=. 
# ============================================================== policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) - afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_code) - # commitment_loss += cross_entropy_loss(encode_output, chance_code) - commitment_loss += torch.nn.MSELoss()(encode_output, chance_code) * 0.01 + + # TODO(pu): + if self._cfg.use_ture_chance_label_in_chance_encoder: + afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, true_chance_one_hot.detach()) + # The encoder is not used i the mcts, so we don't need to calculate the commitment loss. + commitment_loss += torch.nn.MSELoss()(chance_encoding, true_chance_one_hot.float().detach()) + else: + afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_one_hot.detach()) + commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float().detach()) afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_i]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 834937ce7..54289da09 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -232,7 +232,8 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti end_index = beg_index + self.unroll_plus_td_steps - 1 pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] - chance_lst = game_segments[i].chance_segment[beg_index:end_index] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_lst = game_segments[i].chance_segment[beg_index:end_index] beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps @@ -246,8 +247,10 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti if self.policy_config.gumbel_algo: last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, next_segment_improved_policy = pad_improved_policy_prob) else: - #last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, chance_lst) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, next_chances = chance_lst) + else: + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) """ Note: game_segment element shape: @@ -317,7 +320,8 @@ def collect(self, action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} game_segments = [ GameSegment( @@ -371,10 +375,11 @@ def collect(self, action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] - chance = 
[chance_dict[env_id] for env_id in ready_env_id] + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} + chance = [chance_dict[env_id] for env_id in ready_env_id] stack_obs = to_ndarray(stack_obs) @@ -459,16 +464,23 @@ def collect(self, game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], chance_dict[env_id] - ) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], chance_dict[env_id] + ) + else: + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id] + ) # NOTE: the position of code snippet is very important. # the obs['action_mask'] and obs['to_play'] are corresponding to the next action action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) - chance_dict[env_id] = to_ndarray(obs['chance']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict[env_id] = to_ndarray(obs['chance']) if self.policy_config.ignore_done: dones[env_id] = False @@ -592,7 +604,8 @@ def collect(self, action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + if self.policy_config.use_ture_chance_label_in_chance_encoder: + chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) game_segments[env_id] = GameSegment( self._env.action_space, diff --git a/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py b/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py index ec44696c1..eb54c865b 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py @@ -2,21 +2,22 @@ env_name = 'game_2048' action_space_size = 4 -chance_space_size= 16 +use_ture_chance_label_in_chance_encoder = True +chance_space_size = 32 # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 -num_simulations = 50 # TODO(pu):100 +num_simulations = 100 # TODO(pu): 50 update_per_collect = 200 batch_size = 512 max_env_step = int(1e8) reanalyze_ratio = 0. 
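The `chance_space_size = 32` set in the config above matches the chance encoding written by `add_random_2_4_tile` in the 2048 env earlier in this series: 16 indices for a spawned 2 (one per board cell) plus another 16 for a spawned 4. A small worked example (the helper name is illustrative, not part of the patch):

def chance_index(tile_val, row, col, chance_space_size=32):
    # Mirrors the branch in the 2048 env's add_random_2_4_tile.
    if chance_space_size == 16:
        return 4 * row + col                  # cell position only
    elif chance_space_size == 32:
        if tile_val == 2:
            return 4 * row + col              # indices 0..15
        elif tile_val == 4:
            return 16 + 4 * row + col         # indices 16..31

print(chance_index(2, 0, 0))   # 0
print(chance_index(4, 2, 3))   # 16 + 4*2 + 3 = 27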
# collector_env_num = 1 -# n_episode = 2 +# n_episode = 1 # evaluator_env_num = 1 # num_simulations = 5 # update_per_collect = 3 @@ -28,12 +29,13 @@ # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'july10_data_stomz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_rew-morm-false_seed0', + exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0', env=dict( stop_value=int(1e6), env_name=env_name, obs_shape=(16, 4, 4), obs_type='dict_observation', + raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' reward_normalize=False, reward_scale=100, max_tile=int(2**16), # 2**11=2048, 2**16=65536 @@ -53,6 +55,7 @@ discrete_action_encoding_type='one_hot', norm_type='BN', ), + use_ture_chance_label_in_chance_encoder=use_ture_chance_label_in_chance_encoder, mcts_ctree=True, gumbel_algo=False, cuda=True, @@ -60,7 +63,7 @@ game_segment_length=200, update_per_collect=update_per_collect, batch_size=batch_size, - td_steps=10, + td_steps=6, discount_factor=0.999, manual_temperature_decay=True, # optim_type='SGD', @@ -71,7 +74,7 @@ learning_rate=0.003, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=0.1, # default is 0 + ssl_loss_weight=2, # default is 0 n_episode=n_episode, eval_freq=int(2e3), replay_buffer_size=int(2e7), # the size/capacity of replay_buffer, in the terms of transitions. From b22dea7be0fc278f678adb79e51e903b0eca0002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Fri, 4 Aug 2023 16:25:47 +0800 Subject: [PATCH 13/28] sync code --- lzero/policy/muzero.py | 10 ++++++-- lzero/policy/stochastic_muzero.py | 8 +++++-- zoo/game_2048/config/muzero_2048_config.py | 11 +++++---- ...ochastic_muzero_2048_true_chance_config.py | 24 +++++++++++++------ 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index dfc5310ff..44a328eb1 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -148,6 +148,8 @@ class MuZeroPolicy(Policy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS. + use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. @@ -208,7 +210,7 @@ def _init_learn(self) -> None: Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW_official', 'AdamW_nanoGPT'], self._cfg.optim_type # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. 
if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( @@ -221,7 +223,11 @@ def _init_learn(self) -> None: self._optimizer = optim.Adam( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) - elif self._cfg.optim_type == 'AdamW': + elif self._cfg.optim_type == 'AdamW_official': + self._optimizer = optim.AdamW( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW_nanoGPT': self._optimizer = configure_optimizers( model=self._model, weight_decay=self._cfg.weight_decay, diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 16cd70ef2..ee8f318df 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -209,7 +209,7 @@ def _init_learn(self) -> None: Overview: Learn mode init method. Called by ``self.__init__``. Ininitialize the learn model, optimizer and MCTS utils. """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW_official', 'AdamW_nanoGPT'], self._cfg.optim_type # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( @@ -222,7 +222,11 @@ def _init_learn(self) -> None: self._optimizer = optim.Adam( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) - elif self._cfg.optim_type == 'AdamW': + elif self._cfg.optim_type == 'AdamW_official': + self._optimizer = optim.AdamW( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW_nanoGPT': self._optimizer = configure_optimizers( model=self._model, weight_decay=self._cfg.weight_decay, diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 3f087bdc1..5563e7cab 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -13,7 +13,8 @@ num_simulations = 50 update_per_collect = 200 batch_size = 512 -max_env_step = int(5e6) +# max_env_step = int(5e6) +max_env_step = int(1e6) reanalyze_ratio = 0. # collector_env_num = 1 @@ -29,7 +30,7 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_seed0', + exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_adamw-wd1e-6_seed0', env=dict( stop_value=int(1e6), env_name=env_name, @@ -67,13 +68,15 @@ # lr_piecewise_constant_decay=True, # learning_rate=0.2, # init lr for manually decay schedule - optim_type='Adam', + optim_type='AdamW_nanoGPT', lr_piecewise_constant_decay=False, learning_rate=3e-3, + # (float) Weight decay for training policy network. + weight_decay=1e-6, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=2, # default is 0 + ssl_loss_weight=0, # default is 0 n_episode=n_episode, eval_freq=int(2e3), replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
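The new `'AdamW_nanoGPT'` option above delegates to `configure_optimizers`; that helper is not shown in this patch, but the nanoGPT-style grouping it refers to conventionally applies weight decay only to parameters of dimension two or higher (linear and conv weights) and exempts biases and normalisation parameters. A hedged sketch of that convention, not the LightZero implementation itself:

import torch

def configure_optimizers_sketch(model, weight_decay, learning_rate, betas=(0.9, 0.95)):
    # Decay matrix/kernel weights; leave biases and norm parameters undecayed.
    decay_params = [p for p in model.parameters() if p.requires_grad and p.dim() >= 2]
    nodecay_params = [p for p in model.parameters() if p.requires_grad and p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]
    return torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

By contrast, the `'AdamW_official'` branch above passes a single `weight_decay` to `torch.optim.AdamW` over all model parameters, so biases and norm scales are decayed as well.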
diff --git a/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py b/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py index eb54c865b..c2177cf3f 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py @@ -1,8 +1,14 @@ from easydict import EasyDict +import sys +sys.path.append('/mnt/nfs/puyuan/LightZero/zoo/game_2048') + +# export PYTHONPATH='/mnt/nfs/puyuan/LightZero/zoo/game_2048':$PYTHONPATH env_name = 'game_2048' action_space_size = 4 -use_ture_chance_label_in_chance_encoder = True +# use_ture_chance_label_in_chance_encoder = True +use_ture_chance_label_in_chance_encoder = False + chance_space_size = 32 # ============================================================== # begin of the most frequently changed config specified by the user @@ -10,10 +16,12 @@ collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 -num_simulations = 100 # TODO(pu): 50 +# num_simulations = 100 # TODO(pu): 50 +num_simulations = 50 # TODO(pu): 50 + update_per_collect = 200 batch_size = 512 -max_env_step = int(1e8) +max_env_step = int(1e9) reanalyze_ratio = 0. # collector_env_num = 1 @@ -29,7 +37,7 @@ # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0', + exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0_1e9', env=dict( stop_value=int(1e6), env_name=env_name, @@ -53,7 +61,7 @@ # NOTE: whether to use the self_supervised_learning_loss. default is False self_supervised_learning_loss=True, # default is False discrete_action_encoding_type='one_hot', - norm_type='BN', + norm_type='BN', ), use_ture_chance_label_in_chance_encoder=use_ture_chance_label_in_chance_encoder, mcts_ctree=True, @@ -63,7 +71,7 @@ game_segment_length=200, update_per_collect=update_per_collect, batch_size=batch_size, - td_steps=6, + td_steps=10, discount_factor=0.999, manual_temperature_decay=True, # optim_type='SGD', @@ -72,9 +80,11 @@ optim_type='Adam', lr_piecewise_constant_decay=False, learning_rate=0.003, + # learning_rate=0.0003, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=2, # default is 0 + # ssl_loss_weight=2, # default is 0 + ssl_loss_weight=0, # default is 0 n_episode=n_episode, eval_freq=int(2e3), replay_buffer_size=int(2e7), # the size/capacity of replay_buffer, in the terms of transitions. 
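`raw_reward_type='raw'` in the env config above selects the unshaped merge score, while the commented alternative `'merged_tiles_plus_log_max_tile_num'` corresponds to the shaped reward branch added to the env's `step()` earlier in this series: the number of board cells freed by the move plus a small bonus whenever a new highest tile appears. Illustrative arithmetic for one such step:

import numpy as np

empty_before, empty_after = 6, 8          # two merges freed two cells
prev_max_tile, new_max_tile = 256, 512    # the move also created a new highest tile
reward = float(empty_after - empty_before)        # merged-tiles term: 2.0
if new_max_tile > prev_max_tile:
    reward += np.log2(new_max_tile) * 0.1         # log-max-tile bonus: 0.9
print(reward)   # 2.9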
From 860fda1d6073d5922b0d287f87e9a9d5a66da93d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Wed, 9 Aug 2023 01:48:24 +0800 Subject: [PATCH 14/28] polish(pu): polish 2048 env, add env save_render_gif method, add 2048 env unittest, add stochatic muzero model unittest --- .../mcts/tree_search/mcts_ptree_stochastic.py | 35 +--- lzero/model/stochastic_muzero_model.py | 53 +++-- .../tests/test_stochastic_muzero_model.py | 72 +++++++ lzero/policy/stochastic_muzero.py | 12 +- .../config/atari_stochastic_muzero_config.py | 8 +- .../config/rule_based_2048_config.py | 67 ++++--- .../config/stochastic_muzero_2048_config.py | 54 +++--- ...ochastic_muzero_2048_true_chance_config.py | 118 ----------- zoo/game_2048/envs/game_2048_env.py | 183 +++++++++--------- zoo/game_2048/envs/test_game_2048_env.py | 70 +++++++ 10 files changed, 349 insertions(+), 323 deletions(-) create mode 100644 lzero/model/tests/test_stochastic_muzero_model.py delete mode 100644 zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py create mode 100644 zoo/game_2048/envs/test_game_2048_env.py diff --git a/lzero/mcts/tree_search/mcts_ptree_stochastic.py b/lzero/mcts/tree_search/mcts_ptree_stochastic.py index 23f2b35e1..8cbd41f7a 100644 --- a/lzero/mcts/tree_search/mcts_ptree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ptree_stochastic.py @@ -130,7 +130,8 @@ def search( # obtain the states for leaf nodes for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append( - latent_state_batch_in_search_path[ix][iy]) # latent_state_batch_in_search_path[ix][iy] shape e.g. (64,4,4) + latent_state_batch_in_search_path[ix][ + iy]) # latent_state_batch_in_search_path[ix][iy] shape e.g. (64,4,4) latent_states = torch.from_numpy(np.asarray(latent_states)).to(device).float() # only for discrete action @@ -143,7 +144,6 @@ def search( At the end of the simulation, the statistics along the trajectory are updated. 
""" # network_output = model.recurrent_inference(latent_states, last_actions) - num = len(leaf_nodes) latent_state_batch = [None] * num value_batch = [None] * num @@ -196,49 +196,32 @@ def process_nodes(node_indices, is_chance): process_nodes(chance_nodes, True) process_nodes(decision_nodes, False) - # latent_state_batch_chance = [latent_state_batch[leaf_idx] for leaf_idx in chance_nodes] - # latent_state_batch_decision = [latent_state_batch[leaf_idx] for leaf_idx in decision_nodes] value_batch_chance = [value_batch[leaf_idx] for leaf_idx in chance_nodes] value_batch_decision = [value_batch[leaf_idx] for leaf_idx in decision_nodes] reward_batch_chance = [reward_batch[leaf_idx] for leaf_idx in chance_nodes] reward_batch_decision = [reward_batch[leaf_idx] for leaf_idx in decision_nodes] policy_logits_batch_chance = [policy_logits_batch[leaf_idx] for leaf_idx in chance_nodes] - policy_logits_batch_decision = [policy_logits_batch[leaf_idx] for leaf_idx in decision_nodes] + policy_logits_batch_decision = [policy_logits_batch[leaf_idx] for leaf_idx in decision_nodes] latent_state_batch = np.concatenate(latent_state_batch, axis=0) latent_state_batch_in_search_path.append(latent_state_batch) current_latent_state_index = simulation_index + 1 - if(len(chance_nodes) > 0): + if len(chance_nodes) > 0: value_batch_chance = np.concatenate(value_batch_chance, axis=0) reward_batch_chance = np.concatenate(reward_batch_chance, axis=0) policy_logits_batch_chance = np.concatenate(policy_logits_batch_chance, axis=0) tree_muzero.batch_backpropagate( - current_latent_state_index, discount_factor, reward_batch_chance, value_batch_chance, policy_logits_batch_chance, + current_latent_state_index, discount_factor, reward_batch_chance, value_batch_chance, + policy_logits_batch_chance, min_max_stats_lst, results, virtual_to_play, child_is_chance_batch, chance_nodes ) - if(len(decision_nodes)>0): + if len(decision_nodes) > 0: value_batch_decision = np.concatenate(value_batch_decision, axis=0) reward_batch_decision = np.concatenate(reward_batch_decision, axis=0) policy_logits_batch_decision = np.concatenate(policy_logits_batch_decision, axis=0) tree_muzero.batch_backpropagate( - current_latent_state_index, discount_factor, reward_batch_decision, value_batch_decision, policy_logits_batch_decision, + current_latent_state_index, discount_factor, reward_batch_decision, value_batch_decision, + policy_logits_batch_decision, min_max_stats_lst, results, virtual_to_play, child_is_chance_batch, decision_nodes ) - - # latent_state_batch = np.concatenate(latent_state_batch, axis=0) - # value_batch = np.concatenate(value_batch, axis=0) - # reward_batch = np.concatenate(reward_batch, axis=0) - # policy_logits_batch = np.concatenate(policy_logits_batch, axis=0) - # latent_state_batch_in_search_path.append(latent_state_batch) - - # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and - # ``reward`` predicted by the model, then perform backpropagation along the search path to update the - # statistics. - - # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. 
- # current_latent_state_index = simulation_index + 1 - # tree_muzero.batch_backpropagate( - # current_latent_state_index, discount_factor, reward_batch, value_batch, policy_logits_batch, - # min_max_stats_lst, results, virtual_to_play, child_is_chance_batch - # ) diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index a40f027e1..17d24d06c 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -48,8 +48,8 @@ def __init__( ): """ Overview: - The definition of the neural network model used in MuZero. - MuZero model which consists of a representation network, a dynamics network and a prediction network. + The definition of the neural network model used in Stochastic MuZero. + Stochastic MuZero model which consists of a representation network, a dynamics network and a prediction network. The networks are build on convolution residual blocks and fully connected layers. Arguments: - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari. @@ -70,7 +70,7 @@ def __init__( - pred_hid (:obj:`int`): The size of prediction hidden layer. - pred_out (:obj:`int`): The size of prediction output layer. - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - in MuZero model, default set it to False. + in Stochastic MuZero model, default set it to False. - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical \ distribution for value and reward. - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ @@ -203,7 +203,7 @@ def __init__( def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: """ Overview: - Initial inference of MuZero model, which is the first step of the MuZero model. + Initial inference of Stochastic MuZero model, which is the first step of the Stochastic MuZero model. To perform the initial inference, we first use the representation network to obtain the ``latent_state``. Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. Arguments: @@ -236,7 +236,7 @@ def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor, afterstate: bool = False) -> MZNetworkOutput: """ Overview: - Recurrent inference of MuZero model, which is the rollout step of the MuZero model. + Recurrent inference of Stochastic MuZero model, which is the rollout step of the Stochastic MuZero model. To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, ``reward``, by the given current ``latent_state`` and ``action``. We then use the prediction network to predict the ``value`` and ``policy_logits`` of the current @@ -441,7 +441,7 @@ def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.T """ Overview: Project the latent state to a lower dimension to calculate the self-supervised loss, which is involved in - MuZero algorithm in EfficientZero. + in EfficientZero. For more details, please refer to paper ``Exploring Simple Siamese Representation Learning``. Arguments: - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. 
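To make the `afterstate` flag documented in `recurrent_inference` above concrete for the 2048 use case driving this patch series: the afterstate is the board after the player's deterministic slide/merge but before the random tile spawn, and the chance outcome identifies which tile spawned where. A small hand-worked illustration (boards are written out literally rather than computed by the env):

import numpy as np

# state s_t: the player will move left.
state = np.array([[2, 2, 0, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0]])

# afterstate as_t: the two 2s merge deterministically; no tile has spawned yet.
afterstate = np.array([[4, 0, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0]])

# chance outcome: a 2 spawns at (row=1, col=2), i.e. chance index 4 * 1 + 2 = 6.
next_state = afterstate.copy()
next_state[1, 2] = 2

In the model, `recurrent_inference(latent_state, action, afterstate=False)` covers the first (decision) transition and `recurrent_inference(afterstate, chance, afterstate=True)` covers the second (chance) transition, matching the unroll in `StochasticMuZeroPolicy._forward_learn`.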
@@ -495,7 +495,7 @@ def __init__( ): """ Overview: - The definition of dynamics network in MuZero algorithm, which is used to predict next latent state and + The definition of dynamics network in Stochastic MuZero algorithm, which is used to predict next latent state and reward given current latent state and action. Arguments: - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. @@ -596,8 +596,7 @@ def __init__( ): """ Overview: - The definition of dynamics network in MuZero algorithm, which is used to predict next latent state and - reward given current latent state and action. + The definition of afterstate dynamics network in Stochastic MuZero algorithm, which is used to predict next afterstate given current latent state and action. Arguments: - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - num_channels (:obj:`int`): The channels of input, including obs and action encoding. @@ -650,21 +649,21 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to height, width). - reward (:obj:`torch.Tensor`): The predicted reward, with shape (batch_size, output_support_size). """ - # take the state encoding (latent_state), state_action_encoding[:, -1, :, :] is action encoding - latent_state = state_action_encoding[:, :-1, :, :] + # take the state encoding (afterstate), state_action_encoding[:, -1, :, :] is action encoding + afterstate = state_action_encoding[:, :-1, :, :] x = self.conv(state_action_encoding) x = self.bn(x) # the residual link: add state encoding to the state_action encoding - x += latent_state + x += afterstate x = self.activation(x) for block in self.resblocks: x = block(x) - afterstate_latent_state = x + afterstate = x # reward = None - x = self.conv1x1_reward(afterstate_latent_state) + x = self.conv1x1_reward(afterstate) x = self.bn_reward(x) x = self.activation(x) x = x.view(-1, self.flatten_output_size_for_reward_head) @@ -672,7 +671,7 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to # use the fully connected layer to predict reward reward = self.fc_reward_head(x) - return afterstate_latent_state, reward + return afterstate, reward def get_dynamic_mean(self) -> float: return get_dynamic_mean(self) @@ -699,8 +698,8 @@ def __init__( ) -> None: """ Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. + The definition of afterstate policy and value prediction network, which is used to predict value and policy by the + given afterstate. Arguments: - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. @@ -758,33 +757,33 @@ def __init__( last_linear_layer_init_zero=last_linear_layer_init_zero ) - def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: Forward computation of the prediction network. Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + - afterstate (:obj:`torch.Tensor`): input tensor with shape (B, afterstate_dim). Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ - afterstate_policy_logits (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - afterstate_value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). """ for res_block in self.resblocks: - latent_state = res_block(latent_state) + afterstate = res_block(afterstate) - value = self.conv1x1_value(latent_state) + value = self.conv1x1_value(afterstate) value = self.bn_value(value) value = self.activation(value) - policy = self.conv1x1_policy(latent_state) + policy = self.conv1x1_policy(afterstate) policy = self.bn_policy(policy) policy = self.activation(policy) value = value.reshape(-1, self.flatten_output_size_for_value_head) policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) - value = self.fc_value(value) - policy = self.fc_policy(policy) - return policy, value + afterstate_value = self.fc_value(value) + afterstate_policy_logits = self.fc_policy(policy) + return afterstate_policy_logits, afterstate_value class ChanceEncoderBackbone(nn.Module): diff --git a/lzero/model/tests/test_stochastic_muzero_model.py b/lzero/model/tests/test_stochastic_muzero_model.py new file mode 100644 index 000000000..e53b86351 --- /dev/null +++ b/lzero/model/tests/test_stochastic_muzero_model.py @@ -0,0 +1,72 @@ +import torch +import pytest +from torch import nn +from lzero.model.stochastic_muzero_model import ChanceEncoder + +# Initialize a ChanceEncoder instance for testing +@pytest.fixture +def encoder(): + return ChanceEncoder((3, 32, 32), 4) + +def test_ChanceEncoder(encoder): + # Create a dummy tensor for testing + x_and_last_x = torch.randn(1, 6, 32, 32) + + # Forward pass through the encoder + chance_encoding_t, chance_onehot_t = encoder(x_and_last_x) + + # Check the output shapes + assert chance_encoding_t.shape == (1, 4) + assert chance_onehot_t.shape == (1, 4) + + # Check that chance_onehot_t is indeed one-hot + assert torch.all((chance_onehot_t == 0) | (chance_onehot_t == 1)) + assert torch.all(torch.sum(chance_onehot_t, dim=1) == 1) + +def test_ChanceEncoder_gradients_chance_encoding(encoder): + # Create a dummy tensor for testing + x_and_last_x = torch.randn(1, 6, 32, 32) + + # Forward pass through the encoder + chance_encoding_t, chance_onehot_t = encoder(x_and_last_x) + + # Create a dummy target tensor for a simple loss function + target = torch.randn(1, 4) + + # Use mean squared error as a simple loss function + loss = nn.MSELoss()(chance_encoding_t, target) + + # Backward pass + loss.backward() + + # Check if gradients are computed + for param in encoder.parameters(): + assert param.grad is not None + + # Check if gradients have the correct shape + for param in encoder.parameters(): + assert param.grad.shape == param.shape + +def test_ChanceEncoder_gradients_chance_onehot_t(encoder): + # Create a dummy tensor for testing + x_and_last_x = torch.randn(1, 6, 32, 32) + + # Forward pass through the encoder + chance_encoding_t, chance_onehot_t = encoder(x_and_last_x) + + # Create a dummy target tensor for a simple loss function + target = torch.randn(1, 4) + + # Use mean squared error as a simple loss function + loss = nn.MSELoss()(chance_onehot_t, target) + + # Backward pass + loss.backward() + + # Check if gradients are computed + for param in encoder.parameters(): + assert param.grad is not None + + # Check if gradients have the correct shape + for param in encoder.parameters(): + assert param.grad.shape == param.shape diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index ee8f318df..65bc6fbff 100644 --- 
a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -372,14 +372,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # the core recurrent_inference in MuZero policy. # ============================================================== for step_i in range(self._cfg.num_unroll_steps): - # unroll with the afterstate dynamic function: predict 'after_state', + # unroll with the afterstate dynamic function: predict 'afterstate', # given current ``state`` and ``action``. # 'afterstate reward' is not used, we kept it for the sake of uniformity between decision nodes and chance nodes. - # And then predict afterstate policy_logits and afterstate value with the afterstate prediction function. + # And then predict afterstate_policy_logits and afterstate_value with the afterstate prediction function. network_output = self._learn_model.recurrent_inference( latent_state, action_batch[:, step_i], afterstate=False ) - after_state, afterstate_reward, afterstate_value, afterstate_policy_logits = mz_network_output_unpack(network_output) + afterstate, afterstate_reward, afterstate_value, afterstate_policy_logits = mz_network_output_unpack(network_output) # concat consecutive frames to predict chance former_frame = obs_list_for_chance_encoder[step_i] @@ -394,9 +394,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in chance_code = torch.argmax(chance_encoding, dim=1).long().unsqueeze(-1) # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, - # given current ``after_state`` and ``chance_long``. + # given current ``afterstate`` and ``chance_code``. # And then predict policy_logits and value with the prediction function. - network_output = self._learn_model.recurrent_inference(after_state, chance_code, afterstate=True) + network_output = self._learn_model.recurrent_inference(afterstate, chance_code, afterstate=True) latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) @@ -762,8 +762,8 @@ def _monitor_vars_learn(self) -> List[str]: 'value_loss', 'consistency_loss', 'afterstate_policy_loss', - 'commitment_loss', 'afterstate_value_loss', + 'commitment_loss', 'value_priority', 'target_reward', 'target_value', diff --git a/zoo/atari/config/atari_stochastic_muzero_config.py b/zoo/atari/config/atari_stochastic_muzero_config.py index 8083d579d..3854b859c 100644 --- a/zoo/atari/config/atari_stochastic_muzero_config.py +++ b/zoo/atari/config/atari_stochastic_muzero_config.py @@ -25,7 +25,7 @@ batch_size = 256 max_env_step = int(1e6) reanalyze_ratio = 0. 
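Illustration (not part of the patch): the ChanceEncoder tests earlier in this patch require gradients to flow through the discretized chance_onehot_t as well as chance_encoding_t, and the learn loop above collapses the encoding to an integer chance_code with argmax. The usual way to keep a hard one-hot code differentiable is a straight-through estimator; the sketch below shows that trick in isolation and is an assumption, not a quote of the actual ChanceEncoder implementation.

    import torch

    def straight_through_onehot(logits: torch.Tensor) -> torch.Tensor:
        # Soft distribution over chance outcomes (carries the gradients).
        probs = torch.softmax(logits, dim=1)
        # Hard one-hot of the argmax outcome (used in the forward pass).
        index = probs.argmax(dim=1, keepdim=True)
        onehot = torch.zeros_like(probs).scatter_(1, index, 1.0)
        # Forward value equals the hard one-hot; backward follows the soft probs.
        return onehot + probs - probs.detach()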
-chance_space_size = 2 +chance_space_size = 4 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -61,9 +61,9 @@ use_augmentation=True, update_per_collect=update_per_collect, batch_size=batch_size, - optim_type='SGD', - lr_piecewise_constant_decay=True, - learning_rate=0.2, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=3e-3, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=2, # default is 0 diff --git a/zoo/game_2048/config/rule_based_2048_config.py b/zoo/game_2048/config/rule_based_2048_config.py index 17c24d79e..e198b157a 100644 --- a/zoo/game_2048/config/rule_based_2048_config.py +++ b/zoo/game_2048/config/rule_based_2048_config.py @@ -1,26 +1,33 @@ +from functools import lru_cache +from typing import Tuple, Union + import numpy as np -from zoo.game_2048.envs.game_2048_env import Game2048Env, IllegalMove -import pytest from easydict import EasyDict - -from typing import Tuple, Union from rich import print -from functools import lru_cache -import time -import numpy as np + +from zoo.game_2048.envs.game_2048_env import Game2048Env +# Define rule-based search function def rule_based_search(grid: np.array, fast_search: bool = True) -> int: + """ + Overview: + Use Expectimax search algorithm to find the best action. + Adapted from https://github.com/xwjdsh/2048-ai/blob/master/ai/ai.go. + """ + # please refer to https://codemyroad.wordpress.com/2014/05/14/2048-ai-the-intelligent-bot/ model1 = np.array([[16, 15, 14, 13], [9, 10, 11, 12], [8, 7, 6, 5], [1, 2, 2, 4]]) model2 = np.array([[16, 15, 12, 4], [14, 13, 11, 3], [10, 9, 8, 2], [7, 6, 5, 1]]) model3 = np.array([[16, 15, 14, 4], [13, 12, 11, 3], [10, 9, 8, 2], [7, 6, 5, 1]]) + # Use lru_cache decorator for caching, speeding up subsequent look-ups @lru_cache(maxsize=512) def get_model_score(value, i, j): result = np.zeros(3 * 8) for k, m in enumerate([model1, model2, model3]): start = k * 8 result[start] += m[i, j] * value + # Scores of other 7 directions of the model result[start + 1] += m[i, 3 - j] * value result[start + 2] += m[j, i] * value result[start + 3] += m[3 - j, i] * value @@ -31,26 +38,18 @@ def get_model_score(value, i, j): return result def get_score(grid: np.array) -> float: + # Calculate the score of the current layout result = np.zeros(3 * 8) for i in range(4): for j in range(4): if grid[i, j] != 0: result += get_model_score(grid[i, j], i, j) - # for k, m in enumerate([model1, model2, model3]): - # start = k * 8 - # value = grid[i, j] # whether use log2 here - # result[start] += m[i, j] * value - # result[start + 1] += m[i, 3 - j] * value - # result[start + 2] += m[j, i] * value - # result[start + 3] += m[3 - j, i] * value - # result[start + 4] += m[3 - i, 3 - j] * value - # result[start + 5] += m[3 - i, j] * value - # result[start + 6] += m[j, 3 - i] * value - # result[start + 7] += m[3 - j, 3 - i] * value return result.max() def expectation_search(grid: np.array, depth: int, chance_node: bool) -> Tuple[float, Union[int, None]]: + # Use Expectimax search algorithm to find the best action + # please refer to https://courses.cs.washington.edu/courses/cse473/11au/slides/cse473au11-adversarial-search.pdf if depth == 0: return get_score(grid), None if chance_node: @@ -83,7 +82,7 @@ def expectation_search(grid: np.array, depth: int, chance_node: bool) -> Tuple[f best_action = dire return best_score, best_action - # depth 
selection + # Select search depth based on the current maximum tile value grid_max = grid.max() if grid_max >= 2048: depth = 6 @@ -91,11 +90,12 @@ def expectation_search(grid: np.array, depth: int, chance_node: bool) -> Tuple[f depth = 5 else: depth = 4 - # rule_based_search + # Call the expectation search algorithm and return the best action _, best_action = expectation_search(grid, depth, False) return best_action +# Define move function, implement move operation in 2048 game def move(grid: np.array, action: int, game_score: int = 0) -> Tuple[np.array, bool, int]: # execute action in 2048 game # 0, 1, 2, 3 mean top, right, bottom, left @@ -139,9 +139,8 @@ def move(grid: np.array, action: int, game_score: int = 0) -> Tuple[np.array, bo return grid, move_flag, game_score +# # Define generate function, randomly generate 2 or 4 in an empty location def generate(grid: np.array) -> np.array: - # random generate a new number in empty location - # 2 or 4 number = np.random.choice([2, 4], p=[0.9, 0.1]) # get empty location empty = np.where(grid == 0) @@ -152,9 +151,9 @@ def generate(grid: np.array) -> np.array: # return new grid return grid - +# Define game configuration config = EasyDict(dict( - env_name="game_2048_env_2048", + env_name="game_2048", save_replay_gif=False, replay_path_gif=None, replay_path=None, @@ -162,8 +161,7 @@ def generate(grid: np.array) -> np.array: channel_last=True, obs_type='array', raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' - # reward_type='merged_tiles_plus_log_max_tile_num', - # reward_normalize=True, + reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' reward_normalize=False, reward_norm_scale=100, max_tile=int(2 ** 16), @@ -182,20 +180,27 @@ def generate(grid: np.array) -> np.array: game_2048_env.render() step = 0 while True: - # action = env.human_to_action() print('=' * 20) grid = obs.astype(np.int64) + # action = game_2048_env.human_to_action() action = game_2048_env.random_action() - action = rule_based_search(grid) + # action = rule_based_search(grid) if action == 1: action = 2 elif action == 2: action = 1 - obs, reward, done, info = game_2048_env.step(action) + try: + obs, reward, done, info = game_2048_env.step(action) + except Exception as e: + print(f'Exception: {e}') + print('total_step_number: {}'.format(step)) + game_2048_env.save_render_gif(gif_name_suffix='bot') + break step += 1 print(f"step: {step}, action: {action}, reward: {reward}, raw_reward: {info['raw_reward']}") - game_2048_env.render() - + game_2048_env.render(mode='human') + game_2048_env.render(mode='rgb_array_render') if done: print('total_step_number: {}'.format(step)) + game_2048_env.save_render_gif(gif_name_suffix='bot') break diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index a24026f82..69961a33e 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -1,36 +1,42 @@ from easydict import EasyDict +import sys +sys.path.append('/mnt/nfs/puyuan/LightZero/zoo/game_2048') +# export PYTHONPATH='/mnt/nfs/puyuan/LightZero/zoo/game_2048':$PYTHONPATH +env_name = 'game_2048' +action_space_size = 4 +# use_ture_chance_label_in_chance_encoder = True +use_ture_chance_label_in_chance_encoder = False +chance_space_size = 32 # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -env_name 
= 'game_2048' -action_space_size = 4 -chance_space_size = 4 - collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 -num_simulations = 50 +# num_simulations = 100 # TODO(pu): 50 +num_simulations = 50 # TODO(pu): 50 + update_per_collect = 200 batch_size = 512 -max_env_step = int(1e8) +max_env_step = int(1e9) reanalyze_ratio = 0. -collector_env_num = 1 -n_episode = 1 -evaluator_env_num = 1 -num_simulations = 5 -update_per_collect = 3 -batch_size = 5 -max_env_step = int(1e6) -reanalyze_ratio = 0. +# collector_env_num = 1 +# n_episode = 1 +# evaluator_env_num = 1 +# num_simulations = 5 +# update_per_collect = 3 +# batch_size = 5 +# max_env_step = int(1e6) +# reanalyze_ratio = 0. # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_seed0', + exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0_1e9', env=dict( stop_value=int(1e6), env_name=env_name, @@ -39,7 +45,7 @@ raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' reward_normalize=False, reward_scale=100, - max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 + max_tile=int(2**16), # 2**11=2048, 2**16=65536 collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -54,33 +60,33 @@ # NOTE: whether to use the self_supervised_learning_loss. default is False self_supervised_learning_loss=True, # default is False discrete_action_encoding_type='one_hot', - norm_type='BN', + norm_type='BN', ), + use_ture_chance_label_in_chance_encoder=use_ture_chance_label_in_chance_encoder, mcts_ctree=True, gumbel_algo=False, cuda=True, env_type='not_board_games', - game_segment_length=400, + game_segment_length=200, update_per_collect=update_per_collect, batch_size=batch_size, td_steps=10, discount_factor=0.999, manual_temperature_decay=True, - threshold_training_steps_for_final_temperature=int(1e5), # optim_type='SGD', # lr_piecewise_constant_decay=True, # learning_rate=0.2, # init lr for manually decay schedule - optim_type='Adam', lr_piecewise_constant_decay=False, - learning_rate=3e-3, - + learning_rate=0.003, + # learning_rate=0.0003, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=2, # default is 0 + # ssl_loss_weight=2, # default is 0 + ssl_loss_weight=0, # default is 0 n_episode=n_episode, eval_freq=int(2e3), - replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + replay_buffer_size=int(2e7), # the size/capacity of replay_buffer, in the terms of transitions. 
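Note (not part of the patch): the 2048 config diffs here leave the files' unchanged entry points out of view. For orientation only, configs in this repository are launched with the pattern below; the same block is visible verbatim in the deleted stochastic_muzero_2048_true_chance_config.py further down.

    if __name__ == "__main__":
        from lzero.entry import train_muzero
        train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)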
collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ), diff --git a/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py b/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py deleted file mode 100644 index c2177cf3f..000000000 --- a/zoo/game_2048/config/stochastic_muzero_2048_true_chance_config.py +++ /dev/null @@ -1,118 +0,0 @@ -from easydict import EasyDict -import sys -sys.path.append('/mnt/nfs/puyuan/LightZero/zoo/game_2048') - -# export PYTHONPATH='/mnt/nfs/puyuan/LightZero/zoo/game_2048':$PYTHONPATH - -env_name = 'game_2048' -action_space_size = 4 -# use_ture_chance_label_in_chance_encoder = True -use_ture_chance_label_in_chance_encoder = False - -chance_space_size = 32 -# ============================================================== -# begin of the most frequently changed config specified by the user -# ============================================================== -collector_env_num = 8 -n_episode = 8 -evaluator_env_num = 3 -# num_simulations = 100 # TODO(pu): 50 -num_simulations = 50 # TODO(pu): 50 - -update_per_collect = 200 -batch_size = 512 -max_env_step = int(1e9) -reanalyze_ratio = 0. - -# collector_env_num = 1 -# n_episode = 1 -# evaluator_env_num = 1 -# num_simulations = 5 -# update_per_collect = 3 -# batch_size = 5 -# max_env_step = int(1e6) -# reanalyze_ratio = 0. -# ============================================================== -# end of the most frequently changed config specified by the user -# ============================================================== - -game_2048_stochastic_muzero_config = dict( - exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0_1e9', - env=dict( - stop_value=int(1e6), - env_name=env_name, - obs_shape=(16, 4, 4), - obs_type='dict_observation', - raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' - reward_normalize=False, - reward_scale=100, - max_tile=int(2**16), # 2**11=2048, 2**16=65536 - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - n_evaluator_episode=evaluator_env_num, - manager=dict(shared_memory=False, ), - ), - policy=dict( - model=dict( - observation_shape=(16, 4, 4), - action_space_size=action_space_size, - chance_space_size=chance_space_size, - image_channel=16, - # NOTE: whether to use the self_supervised_learning_loss. default is False - self_supervised_learning_loss=True, # default is False - discrete_action_encoding_type='one_hot', - norm_type='BN', - ), - use_ture_chance_label_in_chance_encoder=use_ture_chance_label_in_chance_encoder, - mcts_ctree=True, - gumbel_algo=False, - cuda=True, - env_type='not_board_games', - game_segment_length=200, - update_per_collect=update_per_collect, - batch_size=batch_size, - td_steps=10, - discount_factor=0.999, - manual_temperature_decay=True, - # optim_type='SGD', - # lr_piecewise_constant_decay=True, - # learning_rate=0.2, # init lr for manually decay schedule - optim_type='Adam', - lr_piecewise_constant_decay=False, - learning_rate=0.003, - # learning_rate=0.0003, - num_simulations=num_simulations, - reanalyze_ratio=reanalyze_ratio, - # ssl_loss_weight=2, # default is 0 - ssl_loss_weight=0, # default is 0 - n_episode=n_episode, - eval_freq=int(2e3), - replay_buffer_size=int(2e7), # the size/capacity of replay_buffer, in the terms of transitions. 
- collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - ), -) -game_2048_stochastic_muzero_config = EasyDict(game_2048_stochastic_muzero_config) -main_config = game_2048_stochastic_muzero_config - -game_2048_stochastic_muzero_create_config = dict( - env=dict( - type='game_2048', - import_names=['zoo.game_2048.envs.game_2048_env'], - ), - env_manager=dict(type='subprocess'), - policy=dict( - type='stochastic_muzero', - import_names=['lzero.policy.stochastic_muzero'], - ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) -) -game_2048_stochastic_muzero_create_config = EasyDict(game_2048_stochastic_muzero_create_config) -create_config = game_2048_stochastic_muzero_create_config - -if __name__ == "__main__": - from lzero.entry import train_muzero - train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 03b056ed0..a99e0a0b9 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -14,7 +14,8 @@ from gym import spaces from gym.utils import seeding from six import StringIO - +import matplotlib.pyplot as plt +import imageio @ENV_REGISTRY.register('game_2048') class Game2048Env(gym.Env): @@ -37,7 +38,7 @@ class Game2048Env(gym.Env): ignore_legal_actions=True, need_flatten=False, ) - metadata = {'render.modes': ['human', 'ansi', 'rgb_array']} + metadata = {'render.modes': ['human', 'rgb_array_render']} @classmethod def default_config(cls: type) -> EasyDict: @@ -59,8 +60,13 @@ def __init__(self, cfg: dict) -> None: self.reward_normalize = cfg.reward_normalize self.reward_norm_scale = cfg.reward_norm_scale assert self.reward_type in ['raw', 'merged_tiles_plus_log_max_tile_num'] - assert self.reward_type=='raw' or (self.reward_type=='merged_tiles_plus_log_max_tile_num' and self.reward_normalize==False) + assert self.reward_type == 'raw' or ( + self.reward_type == 'merged_tiles_plus_log_max_tile_num' and self.reward_normalize == False) self.max_tile = cfg.max_tile + # Define the maximum tile that will end the game (e.g. 2048). None means no limit. + # This does not affect the state returned. + assert self.max_tile is None or isinstance(self.max_tile, int) + self.max_episode_steps = cfg.max_episode_steps self.is_collect = cfg.is_collect self.ignore_legal_actions = cfg.ignore_legal_actions @@ -76,17 +82,14 @@ def __init__(self, cfg: dict) -> None: # Members for gym implementation: self._action_space = spaces.Discrete(4) self._observation_space = spaces.Box(0, 1, (self.w, self.h, self.squares), dtype=int) - - self.set_illegal_move_reward(0.) - self.set_max_tile(max_tile=self.max_tile) - self._reward_range = (0., self.max_tile) + self.set_illegal_move_reward(0.) # for render self.grid_size = 70 - # Initialise the random seed of the gym environment. 
self.seed() + self.frames = [] def reset(self): """Reset the game board-matrix and add 2 tiles.""" @@ -192,7 +195,7 @@ def step(self, action): reward = to_ndarray([reward_merged_tiles_plus_log_max_tile_num]).astype(np.float32) elif self.reward_type == 'raw': reward = to_ndarray([reward]).astype(np.float32) - info = {"raw_reward": raw_reward, "max_tile": self.highest(), 'highest': self.highest()} + info = {"raw_reward": raw_reward, "current_max_tile_num": self.highest()} if done: info['eval_episode_return'] = self._final_eval_reward @@ -233,25 +236,28 @@ def move(self, direction, trial=False): if dir_mod_two == 0: # Up or down, split into columns for y in range(self.h): - old = [self.get(x, y) for x in rx] + old = [self.board[x, y] for x in rx] (new, ms) = self.shift(old, shift_direction) move_reward += ms if old != new: changed = True if not trial: for x in rx: - self.set(x, y, new[x]) + self.board[x, y] = new[x] + else: # Left or right, split into rows for x in range(self.w): - old = [self.get(x, y) for y in ry] + old = [self.board[x, y] for y in ry] (new, ms) = self.shift(old, shift_direction) move_reward += ms if old != new: changed = True if not trial: for y in ry: - self.set(x, y, new[y]) + self.board[x, y] = new[y] + + # TODO(pu): different transition dynamics # if not changed: # raise IllegalMove @@ -266,33 +272,9 @@ def set_illegal_move_reward(self, reward): self.illegal_move_reward = reward self.reward_range = (self.illegal_move_reward, float(2 ** self.squares)) - def set_max_tile(self, max_tile: int = 2048): - """ - Define the maximum tile that will end the game (e.g. 2048). None means no limit. - This does not affect the state returned. - """ - assert max_tile is None or isinstance(max_tile, int) - self.max_tile = max_tile - def render(self, mode='human'): - if mode == 'rgb_array': - black = (0, 0, 0) + if mode == 'rgb_array_render': grey = (128, 128, 128) - white = (255, 255, 255) - tile_colour_map = { - 2: (255, 0, 0), - 4: (224, 32, 0), - 8: (192, 64, 0), - 16: (160, 96, 0), - 32: (128, 128, 0), - 64: (96, 160, 0), - 128: (64, 192, 0), - 256: (32, 224, 0), - 512: (0, 255, 0), - 1024: (0, 224, 32), - 2048: (0, 192, 64), - 4096: (0, 160, 96), - } grid_size = self.grid_size # Render with Pillow @@ -303,26 +285,59 @@ def render(self, mode='human'): for y in range(4): for x in range(4): - o = self.get(y, x) + o = self.board[y, x] if o: - draw.rectangle([x * grid_size, y * grid_size, (x + 1) * grid_size, (y + 1) * grid_size], - tile_colour_map[o]) - (text_x_size, text_y_size) = draw.textsize(str(o), font=fnt) - draw.text((x * grid_size + (grid_size - text_x_size) // 2, - y * grid_size + (grid_size - text_y_size) // 2), str(o), font=fnt, fill=white) - assert text_x_size < grid_size - assert text_y_size < grid_size - - return np.asarray(pil_board).swapaxes(0, 1) - - outfile = StringIO() if mode == 'ansi' else sys.stdout - s = 'Current Return: {}, '.format(self.episode_return) - s += 'Highest Tile: {}\n'.format(self.highest()) - npa = np.array(self.board) - grid = npa.reshape((self.size, self.size)) - s += "{}\n".format(grid) - outfile.write(s) - return outfile + self.draw_tile(draw, x, y, o, fnt) + + # Instead of returning the image, we display it using pyplot + plt.imshow(np.asarray(pil_board)) + plt.draw() + # plt.pause(0.001) + # Append the frame to frames for gif + self.frames.append(np.asarray(pil_board)) + elif mode == 'human': + s = 'Current Return: {}, '.format(self.episode_return) + s += 'Current Highest Tile number: {}\n'.format(self.highest()) + npa = np.array(self.board) 
+ grid = npa.reshape((self.size, self.size)) + s += "{}\n".format(grid) + sys.stdout.write(s) + return sys.stdout + + def draw_tile(self, draw, x, y, o, fnt): + grid_size = self.grid_size + white = (255, 255, 255) + tile_colour_map = { + 0: (204, 192, 179), + 2: (238, 228, 218), + 4: (237, 224, 200), + 8: (242, 177, 121), + 16: (245, 149, 99), + 32: (246, 124, 95), + 64: (246, 94, 59), + 128: (237, 207, 114), + 256: (237, 204, 97), + 512: (237, 200, 80), + 1024: (237, 197, 63), + 2048: (237, 194, 46), + 4096: (237, 194, 46), + 8192: (237, 194, 46), + 16384:(237, 194, 46), + } + if o: + draw.rectangle([x * grid_size, y * grid_size, (x + 1) * grid_size, (y + 1) * grid_size], + tile_colour_map[o]) + bbox = draw.textbbox((x, y), str(o), font=fnt) + text_x_size, text_y_size = bbox[2] - bbox[0], bbox[3] - bbox[1] + draw.text((x * grid_size + (grid_size - text_x_size) // 2, + y * grid_size + (grid_size - text_y_size) // 2), str(o), font=fnt, fill=white) + assert text_x_size < grid_size + assert text_y_size < grid_size + + def save_render_gif(self, gif_name_suffix: str = ''): + # At the end of the episode, save the frames as a gif + imageio.mimsave(f'game_2048_{gif_name_suffix}.gif', self.frames) + self.frames = [] # Implementation of game logic for 2048 def add_random_2_4_tile(self): @@ -331,7 +346,6 @@ def add_random_2_4_tile(self): tile_probabilities = np.array([0.9, 0.1]) tile_val = self.np_random.choice(possible_tiles, 1, p=tile_probabilities)[0] empty_location = self.get_empty_location() - # assert empty_location.shape[0] if empty_location.shape[0] == 0: self.should_done = True return @@ -339,6 +353,7 @@ def add_random_2_4_tile(self): empty = empty_location[empty_idx] logging.debug("Adding %s at %s", tile_val, (empty[0], empty[1])) + # set the chance outcome if self.chance_space_size == 16: self.chance = 4 * empty[0] + empty[1] elif self.chance_space_size == 32: @@ -347,15 +362,7 @@ def add_random_2_4_tile(self): elif tile_val == 4: self.chance = 16 + 4 * empty[0] + empty[1] - self.set(empty[0], empty[1], tile_val) - - def get(self, x, y): - """Get the value of one square.""" - return self.board[x, y] - - def set(self, x, y, val): - """Set the value of one square.""" - self.board[x, y] = val + self.board[empty[0], empty[1]] = tile_val def get_empty_location(self): """Return a 2d numpy array with the location of empty squares.""" @@ -365,7 +372,6 @@ def highest(self): """Report the highest tile on the board.""" return np.max(self.board) - @property def legal_actions(self): """ @@ -394,7 +400,7 @@ def legal_actions(self): if dir_mod_two == 0: # Up or down, split into columns for y in range(self.h): - old = [self.get(x, y) for x in rx] + old = [self.board[x, y] for x in rx] (new, move_reward_tmp) = self.shift(old, shift_direction) move_reward += move_reward_tmp if old != new: @@ -402,7 +408,7 @@ def legal_actions(self): else: # Left or right, split into rows for x in range(self.w): - old = [self.get(x, y) for y in ry] + old = [self.board[x, y] for y in ry] (new, move_reward_tmp) = self.shift(old, shift_direction) move_reward += move_reward_tmp if old != new: @@ -414,8 +420,11 @@ def legal_actions(self): return legal_actions def combine(self, shifted_row): - """Combine same tiles when moving to one side. This function always - shifts towards the left. Also count the reward of combined tiles.""" + """ + Overview: + Combine same tiles when moving to one side. This function always + shifts towards the left. Also count the reward of combined tiles. 
+ """ move_reward = 0 combined_row = [0] * self.size skip = False @@ -528,21 +537,6 @@ def __repr__(self) -> str: return "LightZero 2048 Env." -def pairwise(iterable): - """s -> (s0,s1), (s1,s2), (s2, s3), ...""" - a, b = itertools.tee(iterable) - next(b, None) - return zip(a, b) - - -class IllegalMove(Exception): - pass - - -class IllegalActionError(Exception): - pass - - def encode_board(flat_board, num_of_template_tiles=16): """ Overview: @@ -570,3 +564,18 @@ def encode_board(flat_board, num_of_template_tiles=16): one_hot_board = (layered_board == tile_values).astype(int) return one_hot_board + + +def pairwise(iterable): + """s -> (s0,s1), (s1,s2), (s2, s3), ...""" + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + + +class IllegalMove(Exception): + pass + + +class IllegalActionError(Exception): + pass diff --git a/zoo/game_2048/envs/test_game_2048_env.py b/zoo/game_2048/envs/test_game_2048_env.py new file mode 100644 index 000000000..98d6c5bb8 --- /dev/null +++ b/zoo/game_2048/envs/test_game_2048_env.py @@ -0,0 +1,70 @@ +import numpy as np +import pytest +from easydict import EasyDict + +from .game_2048_env import Game2048Env + + +# Create a Game2048 environment that will be used in the following tests. +@pytest.fixture +def env(): + # Configuration for the Game2048 environment + cfg = EasyDict(dict( + env_name="game_2048", + save_replay_gif=False, + replay_path_gif=None, + replay_path=None, + act_scale=True, + channel_last=True, + obs_type='raw_observation', # options=['raw_observation', 'dict_observation', 'array'] + reward_normalize=False, + reward_norm_scale=100, + reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' + max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 + delay_reward_step=0, + prob_random_agent=0., + max_episode_steps=int(1e6), + is_collect=True, + ignore_legal_actions=True, + need_flatten=False, + )) + return Game2048Env(cfg) + + +# Test the initialization of the Game2048 environment. +def test_initialization(env): + assert isinstance(env, Game2048Env) + + +# Test the reset method of the Game2048 environment. +# Ensure that the shape of the observation is as expected. +def test_reset(env): + obs = env.reset() + assert obs.shape == (4, 4, 16) + + +# Test the step method of the Game2048 environment. +# Ensure that the shape of the observation, the type of the reward, +# the type of the done flag and the type of the info are as expected. +def test_step(env): + env.reset() + obs, reward, done, info = env.step(1) + assert obs.shape == (4, 4, 16) + assert isinstance(reward, np.ndarray) + assert isinstance(done, bool) + assert isinstance(info, dict) + + +# Test the render method of the Game2048 environment. +# Ensure that the shape of the rendered image is as expected. +def test_render(env): + env.reset() + env.render(mode='human') + env.render(mode='rgb_array_render') + env.save_render_gif() + +# Test the seed method of the Game2048 environment. +# Ensure that the random seed is set correctly. 
+def test_seed(env): + env.seed(0) + assert env.np_random.randn() != np.random.randn() From 6e13727c262a4c0962cef76cefd42a5413fcad89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Wed, 9 Aug 2023 10:32:14 +0800 Subject: [PATCH 15/28] feature(pu): add stochastic muzero eval config --- lzero/entry/eval_muzero.py | 4 +- .../stochastic_muzero_2048_eval_config.py | 47 +++++++++++++++++++ zoo/game_2048/envs/game_2048_env.py | 27 +++++++---- 3 files changed, 67 insertions(+), 11 deletions(-) create mode 100644 zoo/game_2048/config/stochastic_muzero_2048_eval_config.py diff --git a/lzero/entry/eval_muzero.py b/lzero/entry/eval_muzero.py index f79a02ec2..83b592660 100644 --- a/lzero/entry/eval_muzero.py +++ b/lzero/entry/eval_muzero.py @@ -38,8 +38,8 @@ def eval_muzero( - policy (:obj:`Policy`): Converged policy. """ cfg, create_cfg = input_cfg - assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero'], \ - "LightZero now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'" + assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \ + "LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'" if cfg.policy.cuda and torch.cuda.is_available(): cfg.policy.device = 'cuda' diff --git a/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py b/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py new file mode 100644 index 000000000..a1c4ab212 --- /dev/null +++ b/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py @@ -0,0 +1,47 @@ +# According to the model you want to evaluate, import the corresponding config. +from lzero.entry import eval_muzero +import numpy as np + +if __name__ == "__main__": + """ + model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. 
+ """ + # Take the config of sampled efficientzero as an example + from stochastic_muzero_2048_config import main_config, create_config + + model_path = "/Users/puyuan/code/LightZero/data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns100_upc200_rr0.0_bs512_chance-True-32_seed0/ckpt/ckpt_best.pth.tar" + + returns_mean_seeds = [] + returns_seeds = [] + seeds = [0] + num_episodes_each_seed = 1 + total_test_episodes = num_episodes_each_seed * len(seeds) + create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base + main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 + main_config.env.n_evaluator_episode = total_test_episodes + main_config.env.save_replay_gif = True # Whether to save the gif replay, if save the video render_mode_human must to be True + main_config.env.replay_path_gif = './' + main_config.env.eval_max_episode_steps = int(1e6) # Adjust according to different environments + + for seed in seeds: + returns_mean, returns = eval_muzero( + [main_config, create_config], + seed=seed, + num_episodes_each_seed=num_episodes_each_seed, + print_seed_details=False, + model_path=model_path + ) + print(returns_mean, returns) + returns_mean_seeds.append(returns_mean) + returns_seeds.append(returns) + + returns_mean_seeds = np.array(returns_mean_seeds) + returns_seeds = np.array(returns_seeds) + + print("=" * 20) + print(f'We eval total {len(seeds)} seeds. In each seed, we eval {num_episodes_each_seed} episodes.') + print(f'In seeds {seeds}, returns_mean_seeds is {returns_mean_seeds}, returns is {returns_seeds}') + print('In all seeds, reward_mean:', returns_mean_seeds.mean()) + print("=" * 20) diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index a99e0a0b9..d6ea5272a 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -23,6 +23,7 @@ class Game2048Env(gym.Env): env_name="game_2048", save_replay_gif=False, replay_path_gif=None, + render_real_time=False, replay_path=None, act_scale=True, channel_last=True, @@ -50,9 +51,11 @@ def __init__(self, cfg: dict) -> None: self._cfg = cfg self._init_flag = False self._env_name = cfg.env_name - self._replay_path = cfg.get('replay_path', None) - self._replay_path_gif = cfg.get('replay_path_gif', None) - self._save_replay_gif = cfg.get('save_replay_gif', False) + self.replay_path_gif = cfg.replay_path_gif + self.save_replay_gif = cfg.save_replay_gif + self.render_real_time = cfg.render_real_time + + self._save_replay_count = 0 self.channel_last = cfg.channel_last self.obs_type = cfg.obs_type @@ -199,6 +202,8 @@ def step(self, action): if done: info['eval_episode_return'] = self._final_eval_reward + if self.save_replay_gif: + self.save_render_gif(gif_name_suffix='eval', replay_path_gif=self.replay_path_gif) return BaseEnvTimestep(observation, reward, done, info) @@ -290,9 +295,10 @@ def render(self, mode='human'): self.draw_tile(draw, x, y, o, fnt) # Instead of returning the image, we display it using pyplot - plt.imshow(np.asarray(pil_board)) - plt.draw() - # plt.pause(0.001) + if self.render_real_time: + plt.imshow(np.asarray(pil_board)) + plt.draw() + # plt.pause(0.001) # Append the frame to frames for gif self.frames.append(np.asarray(pil_board)) elif mode == 'human': @@ -322,7 +328,7 @@ def draw_tile(self, draw, x, y, o, fnt): 2048: (237, 194, 46), 4096: (237, 194, 46), 8192: (237, 194, 46), - 16384:(237, 194, 46), + 16384: (237, 194, 46), } if o: draw.rectangle([x * grid_size, y * grid_size, (x + 1) 
* grid_size, (y + 1) * grid_size], @@ -334,9 +340,12 @@ def draw_tile(self, draw, x, y, o, fnt): assert text_x_size < grid_size assert text_y_size < grid_size - def save_render_gif(self, gif_name_suffix: str = ''): + def save_render_gif(self, gif_name_suffix: str = '', replay_path_gif = None): # At the end of the episode, save the frames as a gif - imageio.mimsave(f'game_2048_{gif_name_suffix}.gif', self.frames) + if replay_path_gif is None: + imageio.mimsave(f'game_2048_{gif_name_suffix}.gif', self.frames) + else: + imageio.mimsave(replay_path_gif, self.frames) self.frames = [] # Implementation of game logic for 2048 From 6f519b6fd7826503f0a05d6a3b8f19f7d231b47e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Wed, 9 Aug 2023 12:11:12 +0800 Subject: [PATCH 16/28] polish(pu): polish 2048 save_replay method --- zoo/game_2048/config/muzero_2048_config.py | 12 ++-- .../config/rule_based_2048_config.py | 10 +-- .../stochastic_muzero_2048_eval_config.py | 8 ++- zoo/game_2048/envs/game_2048_env.py | 70 ++++++++++++------- 4 files changed, 64 insertions(+), 36 deletions(-) diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 5563e7cab..f50a5fee7 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -11,7 +11,9 @@ n_episode = 8 evaluator_env_num = 3 num_simulations = 50 -update_per_collect = 200 +# update_per_collect = 200 +update_per_collect = 50 + batch_size = 512 # max_env_step = int(5e6) max_env_step = int(1e6) @@ -30,7 +32,7 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_adamw-wd1e-6_seed0', + exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_adam-wd0_seed0', env=dict( stop_value=int(1e6), env_name=env_name, @@ -68,11 +70,13 @@ # lr_piecewise_constant_decay=True, # learning_rate=0.2, # init lr for manually decay schedule - optim_type='AdamW_nanoGPT', + # optim_type='AdamW_nanoGPT', + # optim_type='AdamW_official', + optim_type='Adam', lr_piecewise_constant_decay=False, learning_rate=3e-3, # (float) Weight decay for training policy network. 
- weight_decay=1e-6, + weight_decay=0, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, diff --git a/zoo/game_2048/config/rule_based_2048_config.py b/zoo/game_2048/config/rule_based_2048_config.py index e198b157a..651adc897 100644 --- a/zoo/game_2048/config/rule_based_2048_config.py +++ b/zoo/game_2048/config/rule_based_2048_config.py @@ -151,11 +151,13 @@ def generate(grid: np.array) -> np.array: # return new grid return grid + # Define game configuration config = EasyDict(dict( env_name="game_2048", - save_replay_gif=False, - replay_path_gif=None, + save_replay=False, + replay_format='mp4', + replay_name_suffix='ns100_s1', replay_path=None, act_scale=True, channel_last=True, @@ -194,13 +196,11 @@ def generate(grid: np.array) -> np.array: except Exception as e: print(f'Exception: {e}') print('total_step_number: {}'.format(step)) - game_2048_env.save_render_gif(gif_name_suffix='bot') + game_2048_env.save_render_gif(replay_name_suffix='bot') break step += 1 print(f"step: {step}, action: {action}, reward: {reward}, raw_reward: {info['raw_reward']}") game_2048_env.render(mode='human') - game_2048_env.render(mode='rgb_array_render') if done: print('total_step_number: {}'.format(step)) - game_2048_env.save_render_gif(gif_name_suffix='bot') break diff --git a/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py b/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py index a1c4ab212..f3899e3d8 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py @@ -21,9 +21,11 @@ create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 main_config.env.n_evaluator_episode = total_test_episodes - main_config.env.save_replay_gif = True # Whether to save the gif replay, if save the video render_mode_human must to be True - main_config.env.replay_path_gif = './' - main_config.env.eval_max_episode_steps = int(1e6) # Adjust according to different environments + main_config.env.save_replay = True # Whether to save the replay, if save the video render_mode_human must to be True + main_config.env.replay_format = 'mp4' + main_config.env.replay_name_suffix = 'ns100_s1' + main_config.env.replay_path = None + main_config.env.max_episode_steps = int(1e9) # Adjust according to different environments for seed in seeds: returns_mean, returns = eval_muzero( diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index d6ea5272a..6b55166fc 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -5,6 +5,9 @@ from typing import List import gym +import imageio +import matplotlib.font_manager as fm +import matplotlib.pyplot as plt import numpy as np from PIL import Image, ImageDraw, ImageFont from ding.envs import BaseEnvTimestep @@ -13,18 +16,17 @@ from easydict import EasyDict from gym import spaces from gym.utils import seeding -from six import StringIO -import matplotlib.pyplot as plt -import imageio + @ENV_REGISTRY.register('game_2048') class Game2048Env(gym.Env): config = dict( env_name="game_2048", - save_replay_gif=False, - replay_path_gif=None, - render_real_time=False, + save_replay=False, + replay_format='gif', + replay_name_suffix='eval', replay_path=None, + render_real_time=False, act_scale=True, channel_last=True, obs_type='raw_observation', # options=['raw_observation', 'dict_observation', 'array'] @@ -51,11 +53,12 
@@ def __init__(self, cfg: dict) -> None: self._cfg = cfg self._init_flag = False self._env_name = cfg.env_name - self.replay_path_gif = cfg.replay_path_gif - self.save_replay_gif = cfg.save_replay_gif + self.replay_format = cfg.replay_format + self.replay_name_suffix = cfg.replay_name_suffix + self.replay_path = cfg.replay_path + self.save_replay = cfg.save_replay self.render_real_time = cfg.render_real_time - self._save_replay_count = 0 self.channel_last = cfg.channel_last self.obs_type = cfg.obs_type @@ -64,7 +67,7 @@ def __init__(self, cfg: dict) -> None: self.reward_norm_scale = cfg.reward_norm_scale assert self.reward_type in ['raw', 'merged_tiles_plus_log_max_tile_num'] assert self.reward_type == 'raw' or ( - self.reward_type == 'merged_tiles_plus_log_max_tile_num' and self.reward_normalize == False) + self.reward_type == 'merged_tiles_plus_log_max_tile_num' and self.reward_normalize == False) self.max_tile = cfg.max_tile # Define the maximum tile that will end the game (e.g. 2048). None means no limit. # This does not affect the state returned. @@ -128,6 +131,8 @@ def reset(self): observation = self.board else: observation = observation + if self.save_replay: + self.render(mode='rgb_array_render') return observation def step(self, action): @@ -135,8 +140,11 @@ def step(self, action): self.episode_length += 1 if action not in self.legal_actions: - raise IllegalActionError( - f"You input illegal action: {action}, the legal_actions are {self.legal_actions}. ") + logging.warning( + f"You input illegal action: {action}, the legal_actions are {self.legal_actions}. " + f"Now we randomly choice a action from self.legal_actions." + ) + action = np.random.choice(self.legal_actions) if self.reward_type == 'merged_tiles_plus_log_max_tile_num': empty_num1 = len(self.get_empty_location()) @@ -200,10 +208,13 @@ def step(self, action): reward = to_ndarray([reward]).astype(np.float32) info = {"raw_reward": raw_reward, "current_max_tile_num": self.highest()} + if self.save_replay: + self.render(mode='rgb_array_render') + if done: info['eval_episode_return'] = self._final_eval_reward - if self.save_replay_gif: - self.save_render_gif(gif_name_suffix='eval', replay_path_gif=self.replay_path_gif) + if self.save_replay: + self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path, format=self.replay_format) return BaseEnvTimestep(observation, reward, done, info) @@ -269,8 +280,9 @@ def move(self, direction, trial=False): return move_reward def set_illegal_move_reward(self, reward): - """Define the reward/penalty for performing an illegal move. Also need - to update the reward range for this.""" + """ + Define the reward/penalty for performing an illegal move. Also need to update the reward range for this. + """ # Guess that the maximum reward is also 2**squares though you'll probably never get that. 
# (assume that illegal move reward is the lowest value that can be returned # TODO: check that this is correct @@ -286,7 +298,8 @@ def render(self, mode='human'): pil_board = Image.new("RGB", (grid_size * 4, grid_size * 4)) draw = ImageDraw.Draw(pil_board) draw.rectangle([0, 0, 4 * grid_size, 4 * grid_size], grey) - fnt = ImageFont.truetype('Arial.ttf', 30) + fnt_path = fm.findfont(fm.FontProperties(family='DejaVu Sans')) + fnt = ImageFont.truetype(fnt_path, 30) for y in range(4): for x in range(4): @@ -337,15 +350,24 @@ def draw_tile(self, draw, x, y, o, fnt): text_x_size, text_y_size = bbox[2] - bbox[0], bbox[3] - bbox[1] draw.text((x * grid_size + (grid_size - text_x_size) // 2, y * grid_size + (grid_size - text_y_size) // 2), str(o), font=fnt, fill=white) - assert text_x_size < grid_size - assert text_y_size < grid_size + # assert text_x_size < grid_size + # assert text_y_size < grid_size - def save_render_gif(self, gif_name_suffix: str = '', replay_path_gif = None): - # At the end of the episode, save the frames as a gif - if replay_path_gif is None: - imageio.mimsave(f'game_2048_{gif_name_suffix}.gif', self.frames) + def save_render_output(self, replay_name_suffix: str = '', replay_path=None, format='gif'): + # At the end of the episode, save the frames + if replay_path is None: + filename = f'game_2048_{replay_name_suffix}.{format}' else: - imageio.mimsave(replay_path_gif, self.frames) + filename = f'{replay_path}.{format}' + + if format == 'gif': + imageio.mimsave(filename, self.frames, 'GIF') + elif format == 'mp4': + imageio.mimsave(filename, self.frames, 'MP4') + else: + raise ValueError("Unsupported format: {}".format(format)) + + logging.info("Saved output to {}".format(filename)) self.frames = [] # Implementation of game logic for 2048 From e5f6b0800be320b8fd01683885aeac2f8f788a2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Wed, 9 Aug 2023 15:57:40 +0800 Subject: [PATCH 17/28] feature(pu): add num_of_possible_chance_tile option in 2048 env --- zoo/game_2048/config/muzero_2048_config.py | 15 +++--- .../config/stochastic_muzero_2048_config.py | 19 ++++--- zoo/game_2048/envs/game_2048_env.py | 49 +++++++++++++++++-- 3 files changed, 65 insertions(+), 18 deletions(-) diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index f50a5fee7..617a3fc87 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -11,13 +11,15 @@ n_episode = 8 evaluator_env_num = 3 num_simulations = 50 -# update_per_collect = 200 -update_per_collect = 50 +update_per_collect = 200 +# update_per_collect = 50 batch_size = 512 -# max_env_step = int(5e6) -max_env_step = int(1e6) +max_env_step = int(5e6) +# max_env_step = int(1e6) reanalyze_ratio = 0. 
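Illustration (not part of the patch): with the save_replay/save_render_output changes above, a rollout that records a replay could look like the sketch below. It is untested and assumes only names introduced or used in this patch (save_replay, replay_format, random_action, the automatic save_render_output call at episode end) plus the BaseEnvTimestep fields from ding.envs.

    from zoo.game_2048.envs.game_2048_env import Game2048Env

    cfg = Game2048Env.default_config()   # EasyDict of the defaults shown above
    cfg.save_replay = True      # frames are collected inside render(mode='rgb_array_render')
    cfg.replay_format = 'gif'   # 'mp4' is also handled by save_render_output
    env = Game2048Env(cfg)

    obs = env.reset()
    done = False
    while not done:
        timestep = env.step(env.random_action())
        done = timestep.done    # the env writes the replay itself once the episode is done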
+num_of_possible_chance_tile = 10 +chance_space_size = 16 * num_of_possible_chance_tile # collector_env_num = 1 # n_episode = 1 @@ -32,7 +34,7 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/game_2048_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_adam-wd0_seed0', + exp_name=f'data_mz_ctree/game_2048_nct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_adam-wd0_seed0', env=dict( stop_value=int(1e6), env_name=env_name, @@ -42,6 +44,7 @@ reward_normalize=False, reward_norm_scale=100, max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 + num_of_possible_chance_tile=num_of_possible_chance_tile, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -76,7 +79,7 @@ lr_piecewise_constant_decay=False, learning_rate=3e-3, # (float) Weight decay for training policy network. - weight_decay=0, + weight_decay=1e-4, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 69961a33e..3a5a149fa 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -1,3 +1,4 @@ +import numpy as np from easydict import EasyDict import sys sys.path.append('/mnt/nfs/puyuan/LightZero/zoo/game_2048') @@ -5,23 +6,25 @@ env_name = 'game_2048' action_space_size = 4 -# use_ture_chance_label_in_chance_encoder = True -use_ture_chance_label_in_chance_encoder = False +use_ture_chance_label_in_chance_encoder = True +# use_ture_chance_label_in_chance_encoder = False -chance_space_size = 32 # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 -# num_simulations = 100 # TODO(pu): 50 -num_simulations = 50 # TODO(pu): 50 +num_simulations = 100 # TODO(pu): 50 +# num_simulations = 50 # TODO(pu): 50 update_per_collect = 200 batch_size = 512 max_env_step = int(1e9) reanalyze_ratio = 0. 
+num_of_possible_chance_tile = 10 +chance_space_size = 16 * num_of_possible_chance_tile + # collector_env_num = 1 # n_episode = 1 @@ -36,7 +39,7 @@ # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0_1e9', + exp_name=f'data_stochastic_mz_ctree/game_2048_nct-{num_of_possible_chance_tile}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0_1e9', env=dict( stop_value=int(1e6), env_name=env_name, @@ -46,6 +49,7 @@ reward_normalize=False, reward_scale=100, max_tile=int(2**16), # 2**11=2048, 2**16=65536 + num_of_possible_chance_tile=num_of_possible_chance_tile, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -79,10 +83,9 @@ optim_type='Adam', lr_piecewise_constant_decay=False, learning_rate=0.003, - # learning_rate=0.0003, + weight_decay=1e-4, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - # ssl_loss_weight=2, # default is 0 ssl_loss_weight=0, # default is 0 n_episode=n_episode, eval_freq=int(2e3), diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 6b55166fc..79b346f80 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -40,6 +40,9 @@ class Game2048Env(gym.Env): is_collect=True, ignore_legal_actions=True, need_flatten=False, + num_of_possible_chance_tile=2, + possible_tiles=np.array([2, 4]), + tile_probabilities=np.array([0.9, 0.1]), ) metadata = {'render.modes': ['human', 'rgb_array_render']} @@ -96,6 +99,14 @@ def __init__(self, cfg: dict) -> None: # Initialise the random seed of the gym environment. self.seed() self.frames = [] + self.num_of_possible_chance_tile = cfg.num_of_possible_chance_tile + self.possible_tiles = cfg.possible_tiles + self.tile_probabilities = cfg.tile_probabilities + if self.num_of_possible_chance_tile > 2: + self.possible_tiles = np.array([2**(i+1) for i in range(self.num_of_possible_chance_tile)]) + self.tile_probabilities = np.array([1/self.num_of_possible_chance_tile for _ in range(self.num_of_possible_chance_tile)]) + assert self.possible_tiles.shape[0] == self.tile_probabilities.shape[0] + assert np.sum(self.tile_probabilities) == 1 def reset(self): """Reset the game board-matrix and add 2 tiles.""" @@ -107,8 +118,12 @@ def reset(self): logging.debug("Adding tiles") # TODO(pu): why add_tiles twice? 
- self.add_random_2_4_tile() - self.add_random_2_4_tile() + if self.num_of_possible_chance_tile > 2: + self.add_random_tile(self.possible_tiles, self.tile_probabilities) + self.add_random_tile(self.possible_tiles, self.tile_probabilities) + elif self.num_of_possible_chance_tile == 2: + self.add_random_2_4_tile() + self.add_random_2_4_tile() action_mask = np.zeros(4, 'int8') action_mask[self.legal_actions] = 1 @@ -160,7 +175,10 @@ def step(self, action): self.episode_return += raw_reward assert raw_reward <= 2 ** (self.w * self.h) - self.add_random_2_4_tile() + if self.num_of_possible_chance_tile > 2: + self.add_random_tile(self.possible_tiles, self.tile_probabilities) + elif self.num_of_possible_chance_tile == 2: + self.add_random_2_4_tile() done = self.is_end() if self.reward_type == 'merged_tiles_plus_log_max_tile_num': reward_merged_tiles_plus_log_max_tile_num = float(reward_merged_tiles_plus_log_max_tile_num) @@ -363,7 +381,8 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path=None, for if format == 'gif': imageio.mimsave(filename, self.frames, 'GIF') elif format == 'mp4': - imageio.mimsave(filename, self.frames, 'MP4') + imageio.mimsave(filename, self.frames, fps=30, codec='mpeg4') + else: raise ValueError("Unsupported format: {}".format(format)) @@ -395,6 +414,28 @@ def add_random_2_4_tile(self): self.board[empty[0], empty[1]] = tile_val + def add_random_tile(self, possible_tiles: np.array = np.array([2, 4]), tile_probabilities: np.array = np.array([0.9, 0.1])): + """Add a tile with a value from possible_tiles array according to given probabilities.""" + if len(possible_tiles) != len(tile_probabilities): + raise ValueError("Length of possible_tiles and tile_probabilities must be the same") + if np.sum(tile_probabilities) != 1: + raise ValueError("Sum of tile_probabilities must be 1") + + tile_val = self.np_random.choice(possible_tiles, 1, p=tile_probabilities)[0] + tile_idx = np.where(possible_tiles == tile_val)[0][0] # get the index of the tile value + empty_location = self.get_empty_location() + if empty_location.shape[0] == 0: + self.should_done = True + return + empty_idx = self.np_random.choice(empty_location.shape[0]) + empty = empty_location[empty_idx] + logging.debug("Adding %s at %s", tile_val, (empty[0], empty[1])) + + # set the chance outcome + self.chance_space_size = len(possible_tiles) * 16 # assuming a 4x4 board + self.chance = tile_idx * 16 + 4 * empty[0] + empty[1] + + self.board[empty[0], empty[1]] = tile_val def get_empty_location(self): """Return a 2d numpy array with the location of empty squares.""" return np.argwhere(self.board == 0) From 7d6f4f1d93f15e30d5def3618776232c33251995 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Fri, 18 Aug 2023 23:57:37 +0800 Subject: [PATCH 18/28] polish(pu): delete collector filed in create config, move eval_config to entry directory --- lzero/mcts/buffer/game_segment.py | 4 +-- .../tests/atari_efficientzero_config_test.py | 5 +--- .../tictactoe_muzero_bot_mode_config_test.py | 5 +--- .../mcts/tree_search/mcts_ctree_stochastic.py | 4 +-- lzero/policy/muzero.py | 2 +- lzero/policy/stochastic_muzero.py | 10 +++---- .../config/atari_efficientzero_config.py | 4 --- .../config/atari_gumbel_muzero_config.py | 4 --- zoo/atari/config/atari_muzero_config.py | 5 ---- .../atari_sampled_efficientzero_config.py | 4 --- .../config/atari_stochastic_muzero_config.py | 4 --- zoo/atari/entry/__init__.py | 0 .../{config => entry}/atari_eval_config.py | 2 +- 
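Note (not part of the patch): add_random_tile above encodes the chance outcome as tile_idx * 16 + 4 * row + col on the 4x4 board, which is why the configs set chance_space_size = 16 * num_of_possible_chance_tile. A small helper to decode such an index for debugging, derived directly from that formula:

    def decode_chance(chance: int, board_size: int = 4):
        """Inverse of chance = tile_idx * board_size**2 + board_size * row + col."""
        cells = board_size * board_size
        tile_idx, cell = divmod(chance, cells)
        row, col = divmod(cell, board_size)
        return tile_idx, row, col

    # e.g. decode_chance(16 + 4 * 2 + 3) == (1, 2, 3): the second possible tile value at row 2, col 3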
...test_atari_sampled_efficientzero_config.py | 4 --- .../config/gomoku_muzero_bot_mode_config.py | 4 --- zoo/board_games/gomoku/entry/__init__.py | 0 .../gomoku_alphazero_eval_config.py | 0 .../gomoku_gumbel_muzero_eval_config.py | 0 .../gomoku_muzero_eval_config.py | 0 .../tictactoe_muzero_bot_mode_config.py | 4 --- .../config/tictactoe_muzero_sp_mode_config.py | 4 --- zoo/board_games/tictactoe/entry/__init__.py | 0 .../tictactoe_alphazero_eval_config.py | 0 .../tictactoe_muzero_eval_config.py | 0 ...alwalker_cont_disc_efficientzero_config.py | 4 --- ..._cont_disc_sampled_efficientzero_config.py | 4 --- zoo/box2d/bipedalwalker/entry/__init__.py | 0 .../bipedalwalker_eval_config.py | 0 ...arlander_cont_disc_efficientzero_config.py | 4 --- ..._cont_disc_sampled_efficientzero_config.py | 4 --- .../lunarlander_disc_efficientzero_config.py | 4 --- zoo/box2d/lunarlander/entry/__init__.py | 0 .../lunarlander_eval_config.py | 0 .../config/cartpole_efficientzero_config.py | 4 --- .../cartpole/config/cartpole_muzero_config.py | 4 --- .../cartpole_sampled_efficientzero_config.py | 4 --- .../cartpole_stochastic_muzero_config.py | 4 --- .../cartpole/entry/__init__.py | 0 .../{config => entry}/cartpole_eval_config.py | 0 ...pendulum_cont_disc_efficientzero_config.py | 4 --- .../pendulum_cont_disc_muzero_config.py | 4 --- ..._cont_disc_sampled_efficientzero_config.py | 4 --- .../pendulum/entry/__init__.py | 0 .../{config => entry}/pendulum_eval_config.py | 0 zoo/game_2048/__init__.py | 0 zoo/game_2048/config/muzero_2048_config.py | 19 ++---------- .../config/stochastic_muzero_2048_config.py | 30 +++++-------------- zoo/game_2048/entry/__init__.py | 0 .../rule_based_2048_config.py | 0 .../stochastic_muzero_2048_eval_config.py | 0 ...ujoco_disc_sampled_efficientzero_config.py | 4 --- .../mujoco_sampled_efficientzero_config.py | 4 --- 52 files changed, 22 insertions(+), 152 deletions(-) create mode 100644 zoo/atari/entry/__init__.py rename zoo/atari/{config => entry}/atari_eval_config.py (97%) create mode 100644 zoo/board_games/gomoku/entry/__init__.py rename zoo/board_games/gomoku/{config => entry}/gomoku_alphazero_eval_config.py (100%) rename zoo/board_games/gomoku/{config => entry}/gomoku_gumbel_muzero_eval_config.py (100%) rename zoo/board_games/gomoku/{config => entry}/gomoku_muzero_eval_config.py (100%) create mode 100644 zoo/board_games/tictactoe/entry/__init__.py rename zoo/board_games/tictactoe/{config => entry}/tictactoe_alphazero_eval_config.py (100%) rename zoo/board_games/tictactoe/{config => entry}/tictactoe_muzero_eval_config.py (100%) create mode 100644 zoo/box2d/bipedalwalker/entry/__init__.py rename zoo/box2d/bipedalwalker/{config => entry}/bipedalwalker_eval_config.py (100%) create mode 100644 zoo/box2d/lunarlander/entry/__init__.py rename zoo/box2d/lunarlander/{config => entry}/lunarlander_eval_config.py (100%) create mode 100644 zoo/classic_control/cartpole/entry/__init__.py rename zoo/classic_control/cartpole/{config => entry}/cartpole_eval_config.py (100%) create mode 100644 zoo/classic_control/pendulum/entry/__init__.py rename zoo/classic_control/pendulum/{config => entry}/pendulum_eval_config.py (100%) create mode 100644 zoo/game_2048/__init__.py create mode 100644 zoo/game_2048/entry/__init__.py rename zoo/game_2048/{config => entry}/rule_based_2048_config.py (100%) rename zoo/game_2048/{config => entry}/stochastic_muzero_2048_eval_config.py (100%) diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index cb2eaeaae..08a3d11b0 100644 --- 
a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -117,8 +117,6 @@ def get_obs(self) -> List: assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format( timestep_obs, timestep_reward ) - # TODO: - timestep = timestep_obs timestep = timestep_reward stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num] if self.config.transform2string: @@ -132,7 +130,7 @@ def append( reward: np.ndarray, action_mask: np.ndarray = None, to_play: int = -1, - chance: np.ndarray = 0, + chance: int = 0, ) -> None: """ Overview: diff --git a/lzero/mcts/tests/atari_efficientzero_config_test.py b/lzero/mcts/tests/atari_efficientzero_config_test.py index 91101d7da..5d3afd7ad 100644 --- a/lzero/mcts/tests/atari_efficientzero_config_test.py +++ b/lzero/mcts/tests/atari_efficientzero_config_test.py @@ -82,6 +82,7 @@ discount_factor=0.997, transform2string=False, lstm_horizon_len=5, + use_ture_chance_label_in_chance_encoder=False, ), ) atari_efficientzero_config = EasyDict(atari_efficientzero_config) @@ -97,10 +98,6 @@ type='efficientzero', import_names=['lzero.policy.efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) create_config = atari_efficientzero_create_config diff --git a/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py b/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py index f5e7463cf..6369a2ea3 100644 --- a/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py +++ b/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py @@ -73,6 +73,7 @@ evaluator_env_num=evaluator_env_num, transform2string=False, lstm_horizon_len=5, + use_ture_chance_label_in_chance_encoder=False, ), ) tictactoe_muzero_config = EasyDict(tictactoe_muzero_config) @@ -88,10 +89,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) tictactoe_muzero_create_config = EasyDict(tictactoe_muzero_create_config) create_config = tictactoe_muzero_create_config diff --git a/lzero/mcts/tree_search/mcts_ctree_stochastic.py b/lzero/mcts/tree_search/mcts_ctree_stochastic.py index f248d8586..a92249fe4 100644 --- a/lzero/mcts/tree_search/mcts_ctree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ctree_stochastic.py @@ -5,12 +5,12 @@ import torch from easydict import EasyDict -from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy +from lzero.policy import InverseScalarTransform from lzero.mcts.ctree.ctree_stochastic_muzero import stochastic_mz_tree # ============================================================== -# MuZero +# Stochastic MuZero # ============================================================== diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 44a328eb1..ad891a776 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -223,7 +223,7 @@ def _init_learn(self) -> None: self._optimizer = optim.Adam( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) - elif self._cfg.optim_type == 'AdamW_official': + elif self._cfg.optim_type == 'AdamW': self._optimizer = optim.AdamW( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 65bc6fbff..b375bba88 100644 --- a/lzero/policy/stochastic_muzero.py +++ 
b/lzero/policy/stochastic_muzero.py @@ -100,9 +100,9 @@ class StochasticMuZeroPolicy(Policy): # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] - optim_type='SGD', + optim_type='Adam', # (float) Learning rate for training policy network. Ininitial lr for manually decay schedule. - learning_rate=0.2, + learning_rate=int(3e-3), # (int) Frequency of target network update. target_update_freq=100, # (float) Weight decay for training policy network. @@ -137,7 +137,7 @@ class StochasticMuZeroPolicy(Policy): ssl_loss_weight=0, # (bool) Whether to use piecewise constant learning rate decay. # i.e. lr: 0.2 -> 0.02 -> 0.002 - lr_piecewise_constant_decay=True, + lr_piecewise_constant_decay=False, # (int) The number of final training iterations to control lr decay, which is only used for manually decay. threshold_training_steps_for_final_lr=int(5e4), # (bool) Whether to use manually decayed temperature. @@ -209,7 +209,7 @@ def _init_learn(self) -> None: Overview: Learn mode init method. Called by ``self.__init__``. Ininitialize the learn model, optimizer and MCTS utils. """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW_official', 'AdamW_nanoGPT'], self._cfg.optim_type + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW', 'AdamW_nanoGPT'], self._cfg.optim_type # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( @@ -222,7 +222,7 @@ def _init_learn(self) -> None: self._optimizer = optim.Adam( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) - elif self._cfg.optim_type == 'AdamW_official': + elif self._cfg.optim_type == 'AdamW': self._optimizer = optim.AdamW( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) diff --git a/zoo/atari/config/atari_efficientzero_config.py b/zoo/atari/config/atari_efficientzero_config.py index 1de1800e2..aedd1edbd 100644 --- a/zoo/atari/config/atari_efficientzero_config.py +++ b/zoo/atari/config/atari_efficientzero_config.py @@ -91,10 +91,6 @@ type='efficientzero', import_names=['lzero.policy.efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) create_config = atari_efficientzero_create_config diff --git a/zoo/atari/config/atari_gumbel_muzero_config.py b/zoo/atari/config/atari_gumbel_muzero_config.py index b4afaf14a..4d39eeeda 100644 --- a/zoo/atari/config/atari_gumbel_muzero_config.py +++ b/zoo/atari/config/atari_gumbel_muzero_config.py @@ -84,10 +84,6 @@ type='gumbel_muzero', import_names=['lzero.policy.gumbel_muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_gumbel_muzero_create_config = EasyDict(atari_gumbel_muzero_create_config) create_config = atari_gumbel_muzero_create_config diff --git a/zoo/atari/config/atari_muzero_config.py b/zoo/atari/config/atari_muzero_config.py index a62d44224..64ed37eaa 100644 --- a/zoo/atari/config/atari_muzero_config.py +++ b/zoo/atari/config/atari_muzero_config.py @@ -25,7 +25,6 @@ batch_size = 256 max_env_step = int(1e6) reanalyze_ratio = 0. 
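# Illustrative sketch (an assumed simplification of the optimizer branches patched above):
# building the optimizer from the config's optim_type / learning_rate / weight_decay fields
# after the 'AdamW_official' name is renamed to 'AdamW'.
# Note: the patched default `learning_rate=int(3e-3)` evaluates to 0; a float literal such as
# 3e-3 is presumably what is intended.
import torch
import torch.optim as optim

def build_optimizer(model: torch.nn.Module, optim_type: str = 'Adam',
                    learning_rate: float = 3e-3, weight_decay: float = 1e-4,
                    momentum: float = 0.9):
    assert optim_type in ['SGD', 'Adam', 'AdamW'], optim_type
    if optim_type == 'SGD':
        return optim.SGD(model.parameters(), lr=learning_rate,
                         momentum=momentum, weight_decay=weight_decay)
    if optim_type == 'Adam':
        return optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    return optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Example: optimizer = build_optimizer(torch.nn.Linear(4, 4), optim_type='Adam')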
- eps_greedy_exploration_in_collect = False # ============================================================== # end of the most frequently changed config specified by the user @@ -95,10 +94,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_muzero_create_config = EasyDict(atari_muzero_create_config) create_config = atari_muzero_create_config diff --git a/zoo/atari/config/atari_sampled_efficientzero_config.py b/zoo/atari/config/atari_sampled_efficientzero_config.py index 4b0c36901..682f85009 100644 --- a/zoo/atari/config/atari_sampled_efficientzero_config.py +++ b/zoo/atari/config/atari_sampled_efficientzero_config.py @@ -85,10 +85,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_sampled_efficientzero_create_config = EasyDict(atari_sampled_efficientzero_create_config) create_config = atari_sampled_efficientzero_create_config diff --git a/zoo/atari/config/atari_stochastic_muzero_config.py b/zoo/atari/config/atari_stochastic_muzero_config.py index 3854b859c..bf359b360 100644 --- a/zoo/atari/config/atari_stochastic_muzero_config.py +++ b/zoo/atari/config/atari_stochastic_muzero_config.py @@ -87,10 +87,6 @@ type='stochastic_muzero', import_names=['lzero.policy.stochastic_muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_stochastic_muzero_create_config = EasyDict(atari_stochastic_muzero_create_config) create_config = atari_stochastic_muzero_create_config diff --git a/zoo/atari/entry/__init__.py b/zoo/atari/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/atari/config/atari_eval_config.py b/zoo/atari/entry/atari_eval_config.py similarity index 97% rename from zoo/atari/config/atari_eval_config.py rename to zoo/atari/entry/atari_eval_config.py index 6f975ac3c..ba897207a 100644 --- a/zoo/atari/config/atari_eval_config.py +++ b/zoo/atari/entry/atari_eval_config.py @@ -23,7 +23,7 @@ main_config.env.n_evaluator_episode = total_test_episodes main_config.env.render_mode_human = True # Whether to enable real-time rendering main_config.env.save_video = True # Whether to save the video, if save the video render_mode_human must to be True - main_config.env.save_path = './' + main_config.env.save_path = '../config/' main_config.env.eval_max_episode_steps=int(1e3) # Adjust according to different environments for seed in seeds: diff --git a/zoo/atari/tests/test_atari_sampled_efficientzero_config.py b/zoo/atari/tests/test_atari_sampled_efficientzero_config.py index 4d830c2f4..122a041f9 100644 --- a/zoo/atari/tests/test_atari_sampled_efficientzero_config.py +++ b/zoo/atari/tests/test_atari_sampled_efficientzero_config.py @@ -84,10 +84,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_sampled_efficientzero_create_config = EasyDict(atari_sampled_efficientzero_create_config) create_config = atari_sampled_efficientzero_create_config \ No newline at end of file diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py index ef0c87000..487218418 100644 --- a/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py +++ 
b/zoo/board_games/gomoku/config/gomoku_muzero_bot_mode_config.py @@ -79,10 +79,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) gomoku_muzero_create_config = EasyDict(gomoku_muzero_create_config) create_config = gomoku_muzero_create_config diff --git a/zoo/board_games/gomoku/entry/__init__.py b/zoo/board_games/gomoku/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_eval_config.py b/zoo/board_games/gomoku/entry/gomoku_alphazero_eval_config.py similarity index 100% rename from zoo/board_games/gomoku/config/gomoku_alphazero_eval_config.py rename to zoo/board_games/gomoku/entry/gomoku_alphazero_eval_config.py diff --git a/zoo/board_games/gomoku/config/gomoku_gumbel_muzero_eval_config.py b/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval_config.py similarity index 100% rename from zoo/board_games/gomoku/config/gomoku_gumbel_muzero_eval_config.py rename to zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval_config.py diff --git a/zoo/board_games/gomoku/config/gomoku_muzero_eval_config.py b/zoo/board_games/gomoku/entry/gomoku_muzero_eval_config.py similarity index 100% rename from zoo/board_games/gomoku/config/gomoku_muzero_eval_config.py rename to zoo/board_games/gomoku/entry/gomoku_muzero_eval_config.py diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py index 265eeb8a5..6c3edd342 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py @@ -77,10 +77,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) tictactoe_muzero_create_config = EasyDict(tictactoe_muzero_create_config) create_config = tictactoe_muzero_create_config diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py index 4f8bdb594..de2e250bd 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_muzero_sp_mode_config.py @@ -76,10 +76,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) tictactoe_muzero_create_config = EasyDict(tictactoe_muzero_create_config) create_config = tictactoe_muzero_create_config diff --git a/zoo/board_games/tictactoe/entry/__init__.py b/zoo/board_games/tictactoe/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_eval_config.py b/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval_config.py similarity index 100% rename from zoo/board_games/tictactoe/config/tictactoe_alphazero_eval_config.py rename to zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval_config.py diff --git a/zoo/board_games/tictactoe/config/tictactoe_muzero_eval_config.py b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval_config.py similarity index 100% rename from zoo/board_games/tictactoe/config/tictactoe_muzero_eval_config.py rename to zoo/board_games/tictactoe/entry/tictactoe_muzero_eval_config.py diff --git 
a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py index dfd7effbd..7208b3a71 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py @@ -77,10 +77,6 @@ type='efficientzero', import_names=['lzero.policy.efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) bipedalwalker_cont_disc_efficientzero_create_config = EasyDict(bipedalwalker_cont_disc_efficientzero_create_config) create_config = bipedalwalker_cont_disc_efficientzero_create_config diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py index 098fc3adb..e8a5965c1 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py @@ -79,10 +79,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) bipedalwalker_cont_disc_sampled_efficientzero_create_config = EasyDict( bipedalwalker_cont_disc_sampled_efficientzero_create_config diff --git a/zoo/box2d/bipedalwalker/entry/__init__.py b/zoo/box2d/bipedalwalker/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_eval_config.py b/zoo/box2d/bipedalwalker/entry/bipedalwalker_eval_config.py similarity index 100% rename from zoo/box2d/bipedalwalker/config/bipedalwalker_eval_config.py rename to zoo/box2d/bipedalwalker/entry/bipedalwalker_eval_config.py diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py index c05304480..b7f90b82b 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py @@ -75,10 +75,6 @@ type='efficientzero', import_names=['lzero.policy.efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) lunarlander_cont_disc_efficientzero_create_config = EasyDict(lunarlander_cont_disc_efficientzero_create_config) create_config = lunarlander_cont_disc_efficientzero_create_config diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py index d10e8b3b7..142f1ed82 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py @@ -77,10 +77,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) lunarlander_cont_disc_sampled_efficientzero_create_config = EasyDict( lunarlander_cont_disc_sampled_efficientzero_create_config diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py index 
2efbebe6b..5f71321e4 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py @@ -70,10 +70,6 @@ type='efficientzero', import_names=['lzero.policy.efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) lunarlander_disc_efficientzero_create_config = EasyDict(lunarlander_disc_efficientzero_create_config) create_config = lunarlander_disc_efficientzero_create_config diff --git a/zoo/box2d/lunarlander/entry/__init__.py b/zoo/box2d/lunarlander/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/box2d/lunarlander/config/lunarlander_eval_config.py b/zoo/box2d/lunarlander/entry/lunarlander_eval_config.py similarity index 100% rename from zoo/box2d/lunarlander/config/lunarlander_eval_config.py rename to zoo/box2d/lunarlander/entry/lunarlander_eval_config.py diff --git a/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py b/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py index 95dc2fa17..705d4d73e 100644 --- a/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py @@ -68,10 +68,6 @@ type='efficientzero', import_names=['lzero.policy.efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) cartpole_efficientzero_create_config = EasyDict(cartpole_efficientzero_create_config) create_config = cartpole_efficientzero_create_config diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py index 8bfd87d5a..86e25a60a 100644 --- a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py @@ -69,10 +69,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) cartpole_muzero_create_config = EasyDict(cartpole_muzero_create_config) create_config = cartpole_muzero_create_config diff --git a/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py b/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py index 7f9822bd6..83f3c9769 100644 --- a/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py @@ -72,10 +72,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) cartpole_sampled_efficientzero_create_config = EasyDict(cartpole_sampled_efficientzero_create_config) create_config = cartpole_sampled_efficientzero_create_config diff --git a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py index 3b30fc4b1..7a47cb95f 100644 --- a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py @@ -72,10 +72,6 @@ type='stochastic_muzero', import_names=['lzero.policy.stochastic_muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) cartpole_stochastic_muzero_create_config = 
EasyDict(cartpole_stochastic_muzero_create_config) create_config = cartpole_stochastic_muzero_create_config diff --git a/zoo/classic_control/cartpole/entry/__init__.py b/zoo/classic_control/cartpole/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/classic_control/cartpole/config/cartpole_eval_config.py b/zoo/classic_control/cartpole/entry/cartpole_eval_config.py similarity index 100% rename from zoo/classic_control/cartpole/config/cartpole_eval_config.py rename to zoo/classic_control/cartpole/entry/cartpole_eval_config.py diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py index d239e87d6..83608c83a 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py @@ -67,10 +67,6 @@ type='efficientzero', import_names=['lzero.policy.efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) pendulum_disc_efficientzero_create_config = EasyDict(pendulum_disc_efficientzero_create_config) create_config = pendulum_disc_efficientzero_create_config diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py index 0d1e4f879..6fc671be1 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py @@ -72,10 +72,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) pendulum_disc_muzero_create_config = EasyDict(pendulum_disc_muzero_create_config) create_config = pendulum_disc_muzero_create_config diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py index f7c74437c..cba4d2f30 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py @@ -70,10 +70,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) pendulum_sampled_efficientzero_create_config = EasyDict(pendulum_sampled_efficientzero_create_config) create_config = pendulum_sampled_efficientzero_create_config diff --git a/zoo/classic_control/pendulum/entry/__init__.py b/zoo/classic_control/pendulum/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/classic_control/pendulum/config/pendulum_eval_config.py b/zoo/classic_control/pendulum/entry/pendulum_eval_config.py similarity index 100% rename from zoo/classic_control/pendulum/config/pendulum_eval_config.py rename to zoo/classic_control/pendulum/entry/pendulum_eval_config.py diff --git a/zoo/game_2048/__init__.py b/zoo/game_2048/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 617a3fc87..73ddb8cce 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -6,21 +6,18 @@ # 
============================================================== env_name = 'game_2048' action_space_size = 4 - collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 num_simulations = 50 update_per_collect = 200 -# update_per_collect = 50 - batch_size = 512 max_env_step = int(5e6) -# max_env_step = int(1e6) reanalyze_ratio = 0. -num_of_possible_chance_tile = 10 +num_of_possible_chance_tile = 2 chance_space_size = 16 * num_of_possible_chance_tile +# debug config # collector_env_num = 1 # n_episode = 1 # evaluator_env_num = 1 @@ -61,7 +58,6 @@ mcts_ctree=True, gumbel_algo=False, cuda=True, - env_type='not_board_games', game_segment_length=200, update_per_collect=update_per_collect, batch_size=batch_size, @@ -69,18 +65,11 @@ discount_factor=0.999, manual_temperature_decay=True, threshold_training_steps_for_final_temperature=int(1e5), - # optim_type='SGD', - # lr_piecewise_constant_decay=True, - # learning_rate=0.2, # init lr for manually decay schedule - - # optim_type='AdamW_nanoGPT', - # optim_type='AdamW_official', optim_type='Adam', lr_piecewise_constant_decay=False, learning_rate=3e-3, # (float) Weight decay for training policy network. weight_decay=1e-4, - num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, ssl_loss_weight=0, # default is 0 @@ -104,10 +93,6 @@ type='muzero', import_names=['lzero.policy.muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) atari_muzero_create_config = EasyDict(atari_muzero_create_config) create_config = atari_muzero_create_config diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 3a5a149fa..15ae2999a 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -1,31 +1,25 @@ -import numpy as np from easydict import EasyDict -import sys -sys.path.append('/mnt/nfs/puyuan/LightZero/zoo/game_2048') -# export PYTHONPATH='/mnt/nfs/puyuan/LightZero/zoo/game_2048':$PYTHONPATH -env_name = 'game_2048' -action_space_size = 4 -use_ture_chance_label_in_chance_encoder = True -# use_ture_chance_label_in_chance_encoder = False # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== +env_name = 'game_2048' +action_space_size = 4 +use_ture_chance_label_in_chance_encoder = True +# use_ture_chance_label_in_chance_encoder = False collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 -num_simulations = 100 # TODO(pu): 50 -# num_simulations = 50 # TODO(pu): 50 - +num_simulations = 100 update_per_collect = 200 batch_size = 512 max_env_step = int(1e9) reanalyze_ratio = 0. 
-num_of_possible_chance_tile = 10 +num_of_possible_chance_tile = 2 chance_space_size = 16 * num_of_possible_chance_tile - +# debug config # collector_env_num = 1 # n_episode = 1 # evaluator_env_num = 1 @@ -39,7 +33,7 @@ # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'data_stochastic_mz_ctree/game_2048_nct-{num_of_possible_chance_tile}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0_1e9', + exp_name=f'data_stochastic_mz_ctree/game_2048_nct-{num_of_possible_chance_tile}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0', env=dict( stop_value=int(1e6), env_name=env_name, @@ -70,16 +64,12 @@ mcts_ctree=True, gumbel_algo=False, cuda=True, - env_type='not_board_games', game_segment_length=200, update_per_collect=update_per_collect, batch_size=batch_size, td_steps=10, discount_factor=0.999, manual_temperature_decay=True, - # optim_type='SGD', - # lr_piecewise_constant_decay=True, - # learning_rate=0.2, # init lr for manually decay schedule optim_type='Adam', lr_piecewise_constant_decay=False, learning_rate=0.003, @@ -107,10 +97,6 @@ type='stochastic_muzero', import_names=['lzero.policy.stochastic_muzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) game_2048_stochastic_muzero_create_config = EasyDict(game_2048_stochastic_muzero_create_config) create_config = game_2048_stochastic_muzero_create_config diff --git a/zoo/game_2048/entry/__init__.py b/zoo/game_2048/entry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/game_2048/config/rule_based_2048_config.py b/zoo/game_2048/entry/rule_based_2048_config.py similarity index 100% rename from zoo/game_2048/config/rule_based_2048_config.py rename to zoo/game_2048/entry/rule_based_2048_config.py diff --git a/zoo/game_2048/config/stochastic_muzero_2048_eval_config.py b/zoo/game_2048/entry/stochastic_muzero_2048_eval_config.py similarity index 100% rename from zoo/game_2048/config/stochastic_muzero_2048_eval_config.py rename to zoo/game_2048/entry/stochastic_muzero_2048_eval_config.py diff --git a/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py b/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py index 9852c53e4..1a0565057 100644 --- a/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py +++ b/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py @@ -102,10 +102,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) mujoco_disc_sampled_efficientzero_create_config = EasyDict(mujoco_disc_sampled_efficientzero_create_config) create_config = mujoco_disc_sampled_efficientzero_create_config diff --git a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py index ab5916c8c..c7cc30c0b 100644 --- a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py +++ b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py @@ -102,10 +102,6 @@ type='sampled_efficientzero', import_names=['lzero.policy.sampled_efficientzero'], ), - collector=dict( - type='episode_muzero', - import_names=['lzero.worker.muzero_collector'], - ) ) 
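# Illustrative sketch (field names other than `policy` are assumptions based on the surrounding
# configs): after this commit the explicit `collector` entry is dropped from every create_config,
# presumably because the training entry constructs the MuZero episode collector directly and the
# field was redundant. A minimal create_config then looks roughly like this:
from easydict import EasyDict

example_stochastic_muzero_create_config = EasyDict(dict(
    env=dict(
        type='atari_lightzero',                      # assumed env registration name
        import_names=['zoo.atari.envs.atari_lightzero_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='stochastic_muzero',
        import_names=['lzero.policy.stochastic_muzero'],
    ),
    # NOTE: no `collector=dict(...)` entry any more.
))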
mujoco_sampled_efficientzero_create_config = EasyDict(mujoco_sampled_efficientzero_create_config) create_config = mujoco_sampled_efficientzero_create_config From 1f82928876e84e1d71f46125f803398662347041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Sat, 19 Aug 2023 18:33:23 +0800 Subject: [PATCH 19/28] sync code --- zoo/board_games/alphabeta_pruning_bot.py | 9 ++++++++- zoo/board_games/tictactoe/envs/tictactoe_env.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/zoo/board_games/alphabeta_pruning_bot.py b/zoo/board_games/alphabeta_pruning_bot.py index 9ebf63125..1f456b682 100644 --- a/zoo/board_games/alphabeta_pruning_bot.py +++ b/zoo/board_games/alphabeta_pruning_bot.py @@ -33,6 +33,8 @@ def expand(self): else: next_start_player_index = 0 if self.is_terminal_node is False: + # Ensure self.legal_actions is valid before the loop + # self.legal_actions = self.env.get_legal_actions(self.board, self.start_player_index) while len(self.legal_actions) > 0: action = self.legal_actions.pop(0) board, legal_actions = self.env.simulate_action_v2(self.board, self.start_player_index, action) @@ -151,7 +153,6 @@ def get_best_action(self, board, player_index, depth=999): if __name__ == "__main__": import time - """ ##### TicTacToe ##### from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv cfg = dict( @@ -160,6 +161,8 @@ def get_best_action(self, board, player_index, depth=999): battle_mode='self_play_mode', agent_vs_human=False, bot_action_type='alpha_beta_pruning', # {'v0', 'alpha_beta_pruning'} + channel_last=True, + scale=True, ) env = TicTacToeEnv(EasyDict(cfg)) player_0 = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 1') # player_index = 0, player = 1 @@ -216,6 +219,8 @@ def get_best_action(self, board, player_index, depth=999): channel_last=True, agent_vs_human=False, bot_action_type='alpha_beta_pruning', # {'v0', 'alpha_beta_pruning'} + prob_random_action_in_bot=0., + check_action_to_connect4_in_bot_v0=False, ) env = GomokuEnv(EasyDict(cfg)) player_0 = AlphaBetaPruningBot(GomokuEnv, cfg, 'player 1') # player_index = 0, player = 1 @@ -261,3 +266,5 @@ def get_best_action(self, board, player_index, depth=999): assert env.get_done_winner()[0] is False, env.get_done_winner()[1] == -1 # assert env.get_done_winner()[0] is True, env.get_done_winner()[1] == 2 + """ + diff --git a/zoo/board_games/tictactoe/envs/tictactoe_env.py b/zoo/board_games/tictactoe/envs/tictactoe_env.py index b5afc9f14..303365dfe 100644 --- a/zoo/board_games/tictactoe/envs/tictactoe_env.py +++ b/zoo/board_games/tictactoe/envs/tictactoe_env.py @@ -121,8 +121,15 @@ def reset(self, start_player_index=0, init_state=None): self.board = np.array(copy.deepcopy(init_state), dtype="int32") else: self.board = np.zeros((self.board_size, self.board_size), dtype="int32") + action_mask = np.zeros(self.total_num_actions, 'int8') - action_mask[self.legal_actions] = 1 + + # TODO(pu): debug + legal_actions = _legal_actions_func_lru(tuple(map(tuple, self.board))) + action_mask[legal_actions] = 1 + + # action_mask[self.legal_actions] = 1 + if self.battle_mode == 'play_with_bot_mode' or self.battle_mode == 'eval_mode': # In ``play_with_bot_mode`` and ``eval_mode``, we need to set the "to_play" parameter in the "obs" dict to -1, # because we don't take into account the alternation between players. 
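# Illustrative sketch (hypothetical stand-in for the _legal_actions_func_lru call above): the
# board is converted to a tuple of tuples so that it becomes hashable, which lets the
# legal-action computation be memoised with functools.lru_cache.
from functools import lru_cache
from typing import Tuple

@lru_cache(maxsize=None)
def _legal_actions_func_lru(board: Tuple[Tuple[int, ...], ...]) -> Tuple[int, ...]:
    size = len(board)
    # a cell is a legal move while it is still empty (0); actions are flattened indices
    return tuple(r * size + c for r in range(size) for c in range(size) if board[r][c] == 0)

# Example: _legal_actions_func_lru(((0, 1, 0), (2, 0, 0), (0, 0, 1))) -> (0, 2, 4, 5, 6, 7)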
From 3a00a35f8b6d95734a2271c3dac44c86c5f99398 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 5 Sep 2023 11:07:10 +0800 Subject: [PATCH 20/28] polish(pu): polish 2048 rule_bot move method, polish 2048 env, polish stochastic muzero game buffer --- .../buffer/game_buffer_stochastic_muzero.py | 575 +---------------- lzero/mcts/ptree/ptree_stochastic_mz.py | 2 - .../mcts/tree_search/mcts_ptree_stochastic.py | 25 +- .../config/atari_stochastic_muzero_config.py | 23 +- zoo/atari/entry/atari_eval_config.py | 12 +- .../tictactoe/envs/tictactoe_env.py | 7 +- .../cartpole_stochastic_muzero_config.py | 3 +- zoo/game_2048/config/muzero_2048_config.py | 6 +- .../config/stochastic_muzero_2048_config.py | 9 +- zoo/game_2048/entry/rule_based_2048_config.py | 32 +- zoo/game_2048/envs/game_2048_env.py | 592 +++++++++--------- zoo/game_2048/envs/test_game_2048_env.py | 101 ++- 12 files changed, 476 insertions(+), 911 deletions(-) diff --git a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py index d8b7dbf4c..c068f265e 100644 --- a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py +++ b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py @@ -1,24 +1,17 @@ -from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional +from typing import Any, Tuple import numpy as np -import torch from ding.utils import BUFFER_REGISTRY -from lzero.mcts.tree_search.mcts_ctree_stochastic import StochasticMuZeroMCTSCtree as MCTSCtree -from lzero.mcts.tree_search.mcts_ptree_stochastic import StochasticMuZeroMCTSPtree as MCTSPtree from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform -from .game_buffer import GameBuffer - -if TYPE_CHECKING: - from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from .game_buffer_muzero import MuZeroGameBuffer @BUFFER_REGISTRY.register('game_buffer_stochastic_muzero') -class StochasticMuZeroGameBuffer(GameBuffer): +class StochasticMuZeroGameBuffer(MuZeroGameBuffer): """ Overview: - The specific game buffer for MuZero policy. + The specific game buffer for Stochastic MuZero policy. """ def __init__(self, cfg: dict): @@ -48,49 +41,6 @@ def __init__(self, cfg: dict): self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] - def sample( - self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] - ) -> List[Any]: - """ - Overview: - sample data from ``GameBuffer`` and prepare the current and target batch for training. - Arguments: - - batch_size (:obj:`int`): batch size. - - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. - Returns: - - train_data (:obj:`List`): List of train data, including current_batch and target_batch. 
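# Illustrative sketch (a simplified, standalone version of the slicing in _make_batch above):
# when use_ture_chance_label_in_chance_encoder is enabled, each unrolled action a_t is
# presumably paired with the chance outcome stored one step later in the segment, i.e. the
# stochastic event observed after that action, hence the `1 +` offset.
def unroll_targets(action_segment, chance_segment, pos, num_unroll_steps):
    actions = list(action_segment[pos:pos + num_unroll_steps])
    chances = list(chance_segment[1 + pos:1 + pos + num_unroll_steps])
    # mask marks the transitions that really exist; the tail beyond the end of the segment is
    # padded so every sample has exactly num_unroll_steps entries (padding values are arbitrary)
    mask = [1.0] * len(actions) + [0.0] * (num_unroll_steps - len(actions))
    actions += [0] * (num_unroll_steps - len(actions))
    chances += [0] * (num_unroll_steps - len(chances))
    return actions, chances, mask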
- """ - policy._target_model.to(self._cfg.device) - policy._target_model.eval() - - # obtain the current_batch and prepare target context - reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( - batch_size, self._cfg.reanalyze_ratio - ) - # target reward, target value - batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model - ) - # target policy - batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) - batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size - ) - - # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies - if 0 < self._cfg.reanalyze_ratio < 1: - batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) - elif self._cfg.reanalyze_ratio == 1: - batch_target_policies = batch_target_policies_re - elif self._cfg.reanalyze_ratio == 0: - batch_target_policies = batch_target_policies_non_re - - target_batch = [batch_rewards, batch_target_values, batch_target_policies] - - # a batch contains the current_batch and the target_batch - train_data = [current_batch, target_batch] - return train_data - def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -119,10 +69,10 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: pos_in_game_segment = pos_in_game_segment_list[i] actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() + self._cfg.num_unroll_steps].tolist() if self._cfg.use_ture_chance_label_in_chance_encoder: chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() + self._cfg.num_unroll_steps].tolist() # add mask for invalid actions (out of trajectory) mask_tmp = [1. for i in range(len(actions_tmp))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))] @@ -196,516 +146,3 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: context = reward_value_context, policy_re_context, policy_non_re_context, current_batch return context - - def _prepare_reward_value_context( - self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], - total_transitions: int - ) -> List[Any]: - """ - Overview: - prepare the context of rewards and values for calculating TD value target in reanalyzing part. 
- Arguments: - - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer - - game_segment_list (:obj:`list`): list of game segments - - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment - - total_transitions (:obj:`int`): number of collected transitions - Returns: - - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, - td_steps_list, action_mask_segment, to_play_segment - """ - zero_obs = game_segment_list[0].zero_obs() - value_obs_list = [] - # the value is valid or not (out of game_segment) - value_mask = [] - rewards_list = [] - game_segment_lens = [] - # for board games - action_mask_segment, to_play_segment = [], [] - - td_steps_list = [] - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - - td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32) - - # prepare the corresponding observations for bootstrapped values o_{t+k} - # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps] - # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14] - game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps) - - rewards_list.append(game_segment.reward_segment) - - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - # get the bootstrapped target obs - td_steps_list.append(td_steps) - # index of bootstrapped obs o_{t+td_steps} - bootstrap_index = current_index + td_steps - - if bootstrap_index < game_segment_len: - value_mask.append(1) - # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps - beg_index = current_index - state_index - end_index = beg_index + self._cfg.model.frame_stack_num - # the stacked obs in time t - obs = game_obs[beg_index:end_index] - else: - value_mask.append(0) - obs = zero_obs - - value_obs_list.append(obs) - - reward_value_context = [ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, - action_mask_segment, to_play_segment - ] - return reward_value_context - - def _prepare_policy_non_reanalyzed_context( - self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int] - ) -> List[Any]: - """ - Overview: - prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play - Arguments: - - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer - - game_segment_list (:obj:`list`): list of game segments - - pos_in_game_segment_list (:obj:`list`): list transition index in game - Returns: - - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment - """ - child_visits = [] - game_segment_lens = [] - # for board games - action_mask_segment, to_play_segment = [], [] - - for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - 
to_play_segment.append(game_segment.to_play_segment) - - child_visits.append(game_segment.child_visit_segment) - - policy_non_re_context = [ - pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment - ] - return policy_non_re_context - - def _prepare_policy_reanalyzed_context( - self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] - ) -> List[Any]: - """ - Overview: - prepare the context of policies for calculating policy target in reanalyzing part. - Arguments: - - batch_index_list (:obj:'list'): start transition index in the replay buffer - - game_segment_list (:obj:'list'): list of game segments - - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history - Returns: - - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, - child_visits, game_segment_lens, action_mask_segment, to_play_segment - """ - zero_obs = game_segment_list[0].zero_obs() - with torch.no_grad(): - # for policy - policy_obs_list = [] - policy_mask = [] - # 0 -> Invalid target policy for padding outside of game segments, - # 1 -> Previous target policy for game segments. - rewards, child_visits, game_segment_lens = [], [], [] - # for board games - action_mask_segment, to_play_segment = [], [] - for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - rewards.append(game_segment.reward_segment) - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - child_visits.append(game_segment.child_visit_segment) - # prepare the corresponding observations - game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - - if current_index < game_segment_len: - policy_mask.append(1) - beg_index = current_index - state_index - end_index = beg_index + self._cfg.model.frame_stack_num - obs = game_obs[beg_index:end_index] - else: - policy_mask.append(0) - obs = zero_obs - policy_obs_list.append(obs) - - policy_re_context = [ - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, - action_mask_segment, to_play_segment - ] - return policy_re_context - - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]: - """ - Overview: - prepare reward and value targets from the context of rewards and values. 
- Arguments: - - reward_value_context (:obj:'list'): the reward value context - - model (:obj:'torch.tensor'):model of the target model - Returns: - - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix - - batch_target_values (:obj:'np.ndarray): batch of value estimation - """ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ - to_play_segment = reward_value_context # noqa - # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) - transition_batch_size = len(value_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. - action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - batch_target_values, batch_rewards = [], [] - with torch.no_grad(): - value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - - m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float() - - # calculate the target value - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - - network_output.append(m_output) - - # concat the output slices after model inference - if self._cfg.use_root_value: - # use the root values from MCTS, as in EfficiientZero - # the root values have limited improvement but require much more GPU actors; - _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') - reward_pool = reward_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - 
MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) - - roots_values = roots.get_values() - value_list = np.array(roots_values) - else: - # use the predicted values - value_list = concat_output_value(network_output) - - # get last state value - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - value_list = value_list.reshape(-1) * np.array( - [ - self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % - 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] - for i in range(transition_batch_size) - ] - ) - else: - value_list = value_list.reshape(-1) * ( - np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list - ) - - value_list = value_list * np.array(value_mask) - value_list = value_list.tolist() - horizon_id, value_index = 0, 0 - - for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, - pos_in_game_segment_list, - to_play_segment): - target_values = [] - target_rewards = [] - base_index = state_index - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - bootstrap_index = current_index + td_steps_list[value_index] - # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]): - for i, reward in enumerate(reward_list[current_index:bootstrap_index]): - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - if to_play_list[base_index] == to_play_list[i]: - value_list[value_index] += reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += -reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += reward * self._cfg.discount_factor ** i - horizon_id += 1 - - if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) - target_rewards.append(reward_list[current_index]) - else: - target_values.append(0) - target_rewards.append(0.0) - # TODO: check - # target_rewards.append(reward) - value_index += 1 - - batch_rewards.append(target_rewards) - batch_target_values.append(target_values) - - batch_rewards = np.asarray(batch_rewards, dtype=object) - batch_target_values = np.asarray(batch_target_values, dtype=object) - return batch_rewards, batch_target_values - - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: - """ - Overview: - prepare policy targets from the reanalyzed context of policies - Arguments: - - policy_re_context (:obj:`List`): List of policy context to reanalyzed - Returns: - - batch_target_policies_re - """ - if policy_re_context is None: - return [] - batch_target_policies_re = [] - - # for board games - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ - to_play_segment = policy_re_context # noqa - # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) - transition_batch_size = len(policy_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. 
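# Illustrative sketch (the single-player case of the value-target loop deleted above, which is
# now inherited from MuZeroGameBuffer): the n-step target bootstraps from the predicted value
# td_steps ahead and accumulates discounted rewards in between,
#     z_t = sum_{i=0}^{k-1} gamma^i * r_{t+i} + gamma^k * v_{t+k},  with k = td_steps.
def n_step_value_target(rewards, t, td_steps, discount_factor, bootstrap_value=0.0):
    # bootstrap_value is the model value at step t + td_steps; pass 0.0 when that step lies
    # beyond the end of the segment, mirroring value_mask in the buffer code
    target = (discount_factor ** td_steps) * bootstrap_value
    for i, reward in enumerate(rewards[t:t + td_steps]):
        target += (discount_factor ** i) * reward
    return target

# Example: rewards=[1, 1, 1], t=0, td_steps=2, gamma=0.99, bootstrap_value=5
# -> 1 + 0.99 * 1 + 0.99**2 * 5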
- action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - with torch.no_grad(): - policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) - # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors - slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size)) - network_output = [] - for i in range(slices): - beg_index = self._cfg.mini_infer_size * i - end_index = self._cfg.mini_infer_size * (i + 1) - m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float() - m_output = model.initial_inference(m_obs) - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - - network_output.append(m_output) - - _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') - reward_pool = reward_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size - ).astype(np.float32).tolist() for _ in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) - - roots_legal_actions_list = legal_actions - roots_distributions = roots.get_distributions() - policy_index = 0 - for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list): - target_policies = [] - - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - distributions = roots_distributions[policy_index] - - if policy_mask[policy_index] == 0: - # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) - else: - if distributions is None: - # if at some obs, the legal_action is None, add the fake target_policy - target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) - ) - else: - if self._cfg.env_type == 'not_board_games': - # for atari/classic_control/box2d environments that only have one player. 
- sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - target_policies.append(policy) - else: - # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] - # to make sure target_policies have the same dimension - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): - policy_tmp[legal_action] = policy[index] - target_policies.append(policy_tmp) - - policy_index += 1 - - batch_target_policies_re.append(target_policies) - - batch_target_policies_re = np.array(batch_target_policies_re) - - return batch_target_policies_re - - def _compute_target_policy_non_reanalyzed( - self, policy_non_re_context: List[Any], policy_shape: Optional[int] - ) -> np.ndarray: - """ - Overview: - prepare policy targets from the non-reanalyzed context of policies - Arguments: - - policy_non_re_context (:obj:`List`): List containing: - - pos_in_game_segment_list - - child_visits - - game_segment_lens - - action_mask_segment - - to_play_segment - - policy_shape: self._cfg.model.action_space_size - Returns: - - batch_target_policies_non_re - """ - batch_target_policies_non_re = [] - if policy_non_re_context is None: - return batch_target_policies_non_re - - pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context - game_segment_batch_size = len(pos_in_game_segment_list) - transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. - action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - with torch.no_grad(): - policy_index = 0 - # 0 -> Invalid target policy for padding outside of game segments, - # 1 -> Previous target policy for game segments. - policy_mask = [] - for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits, - pos_in_game_segment_list): - target_policies = [] - - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - if current_index < game_segment_len: - policy_mask.append(1) - # NOTE: child_visit is already a distribution - distributions = child_visit[current_index] - if self._cfg.env_type == 'not_board_games': - # for atari/classic_control/box2d environments that only have one player. - target_policies.append(distributions) - else: - # for board games that have two players. 
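# NOTE (worked example, not part of the patch): for board games the visit-count
# distribution is defined only over the legal actions, so it has to be scattered back
# onto the full action space, exactly as the board-game branches above and below do.
# The ``_example`` names are illustrative only.
action_space_size_example = 9                    # e.g. a 3x3 board
legal_actions_example = [0, 4, 8]                # legal moves at this position
distributions_example = [0.2, 0.5, 0.3]          # normalized visit counts over legal moves
policy_full_example = [0.0] * action_space_size_example
for index_example, legal_action_example in enumerate(legal_actions_example):
    policy_full_example[legal_action_example] = distributions_example[index_example]
# policy_full_example == [0.2, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.3]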
- policy_tmp = [0 for _ in range(policy_shape)] - for index, legal_action in enumerate(legal_actions[policy_index]): - # only the action in ``legal_action`` the policy logits is nonzero - policy_tmp[legal_action] = distributions[index] - target_policies.append(policy_tmp) - else: - # NOTE: the invalid padding target policy, O is to make sure the correspoding cross_entropy_loss=0 - policy_mask.append(0) - target_policies.append([0 for _ in range(policy_shape)]) - - policy_index += 1 - - batch_target_policies_non_re.append(target_policies) - batch_target_policies_non_re = np.asarray(batch_target_policies_non_re) - return batch_target_policies_non_re - - def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: - """ - Overview: - Update the priority of training data. - Arguments: - - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. - - batch_priorities (:obj:`batch_priorities`): priorities to update to. - NOTE: - train_data = [current_batch, target_batch] - current_batch = [obs_list, action_list, mask_list, batch_index_list, weights, make_time_list] - """ - indices = train_data[0][3] - metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities} - # only update the priorities for data still in replay buffer - for i in range(len(indices)): - if metas['make_time'][i] > self.clear_time: - idx, prio = indices[i], metas['batch_priorities'][i] - self.game_pos_priorities[idx] = prio diff --git a/lzero/mcts/ptree/ptree_stochastic_mz.py b/lzero/mcts/ptree/ptree_stochastic_mz.py index 5e17c8b32..b9cf679ee 100644 --- a/lzero/mcts/ptree/ptree_stochastic_mz.py +++ b/lzero/mcts/ptree/ptree_stochastic_mz.py @@ -599,8 +599,6 @@ def batch_backpropagate( """ if leaf_idx_list is None: leaf_idx_list = list(range(results.num)) - # for i in range(results.num): - # for i in leaf_idx_list: for leaf_order, i in enumerate(leaf_idx_list): # ****** expand the leaf node ****** if to_play is None: diff --git a/lzero/mcts/tree_search/mcts_ptree_stochastic.py b/lzero/mcts/tree_search/mcts_ptree_stochastic.py index 8cbd41f7a..eb6f5f4b3 100644 --- a/lzero/mcts/tree_search/mcts_ptree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ptree_stochastic.py @@ -7,7 +7,7 @@ from lzero.mcts.ptree import MinMaxStatsList from lzero.policy import InverseScalarTransform -import lzero.mcts.ptree.ptree_stochastic_mz as tree_muzero +import lzero.mcts.ptree.ptree_stochastic_mz as tree_stochastic_muzero if TYPE_CHECKING: import lzero.mcts.ptree.ptree_stochastic_mz as stochastic_mz_ptree @@ -110,7 +110,7 @@ def search( latent_states = [] # prepare a result wrapper to transport results between python and c++ parts - results = tree_muzero.SearchResults(num=num) + results = tree_stochastic_muzero.SearchResults(num=num) # latent_state_index_in_search_path: The first index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, that is, the search depth. # latent_state_index_in_batch: The second index of the latent state corresponding to the leaf node in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. @@ -119,10 +119,10 @@ def search( MCTS stage 1: Selection Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. 
""" - # leaf_nodes, latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play = tree_muzero.batch_traverse( + # leaf_nodes, latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play = tree_stochastic_muzero.batch_traverse( # roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, copy.deepcopy(to_play) # ) - results, virtual_to_play = tree_muzero.batch_traverse( + results, virtual_to_play = tree_stochastic_muzero.batch_traverse( roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, copy.deepcopy(to_play) ) leaf_nodes, latent_state_index_in_search_path, latent_state_index_in_batch, last_actions = results.nodes, results.latent_state_index_in_search_path, results.latent_state_index_in_batch, results.last_actions @@ -137,12 +137,13 @@ def search( # only for discrete action last_actions = torch.from_numpy(np.asarray(last_actions)).to(device).long() """ - MCTS stage 2: Expansion - At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. - Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) - MCTS stage 3: Backup - At the end of the simulation, the statistics along the trajectory are updated. - """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. + """ + # network_output = model.recurrent_inference(latent_states, last_actions) num = len(leaf_nodes) latent_state_batch = [None] * num @@ -211,7 +212,7 @@ def process_nodes(node_indices, is_chance): value_batch_chance = np.concatenate(value_batch_chance, axis=0) reward_batch_chance = np.concatenate(reward_batch_chance, axis=0) policy_logits_batch_chance = np.concatenate(policy_logits_batch_chance, axis=0) - tree_muzero.batch_backpropagate( + tree_stochastic_muzero.batch_backpropagate( current_latent_state_index, discount_factor, reward_batch_chance, value_batch_chance, policy_logits_batch_chance, min_max_stats_lst, results, virtual_to_play, child_is_chance_batch, chance_nodes @@ -220,7 +221,7 @@ def process_nodes(node_indices, is_chance): value_batch_decision = np.concatenate(value_batch_decision, axis=0) reward_batch_decision = np.concatenate(reward_batch_decision, axis=0) policy_logits_batch_decision = np.concatenate(policy_logits_batch_decision, axis=0) - tree_muzero.batch_backpropagate( + tree_stochastic_muzero.batch_backpropagate( current_latent_state_index, discount_factor, reward_batch_decision, value_batch_decision, policy_logits_batch_decision, min_max_stats_lst, results, virtual_to_play, child_is_chance_batch, decision_nodes diff --git a/zoo/atari/config/atari_stochastic_muzero_config.py b/zoo/atari/config/atari_stochastic_muzero_config.py index bf359b360..91dfe45c7 100644 --- a/zoo/atari/config/atari_stochastic_muzero_config.py +++ b/zoo/atari/config/atari_stochastic_muzero_config.py @@ -17,12 +17,23 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -collector_env_num = 8 -n_episode = 8 
-evaluator_env_num = 3 -num_simulations = 50 -update_per_collect = 1000 -batch_size = 256 +# collector_env_num = 8 +# n_episode = 8 +# evaluator_env_num = 3 +# num_simulations = 50 +# update_per_collect = 1000 +# batch_size = 256 +# max_env_step = int(1e6) +# reanalyze_ratio = 0. +# chance_space_size = 4 + +# debug config +collector_env_num = 1 +n_episode = 1 +evaluator_env_num = 1 +num_simulations = 5 +update_per_collect = 10 +batch_size = 2 max_env_step = int(1e6) reanalyze_ratio = 0. chance_space_size = 4 diff --git a/zoo/atari/entry/atari_eval_config.py b/zoo/atari/entry/atari_eval_config.py index ba897207a..824ffd1ab 100644 --- a/zoo/atari/entry/atari_eval_config.py +++ b/zoo/atari/entry/atari_eval_config.py @@ -9,7 +9,7 @@ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ # Take the config of sampled efficientzero as an example - from atari_sampled_efficientzero_config import main_config, create_config + from zoo.atari.config.atari_sampled_efficientzero_config import main_config, create_config model_path = "/path/ckpt/ckpt_best.pth.tar" @@ -18,13 +18,13 @@ seeds = [0] num_episodes_each_seed = 1 total_test_episodes = num_episodes_each_seed * len(seeds) - create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base - main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 + create_config.env_manager.type = 'base' # Visualization requires the 'type' to be set as base + main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 main_config.env.n_evaluator_episode = total_test_episodes - main_config.env.render_mode_human = True # Whether to enable real-time rendering - main_config.env.save_video = True # Whether to save the video, if save the video render_mode_human must to be True + main_config.env.render_mode_human = True # Whether to enable real-time rendering + main_config.env.save_video = True # Whether to save the video, if save the video render_mode_human must to be True main_config.env.save_path = '../config/' - main_config.env.eval_max_episode_steps=int(1e3) # Adjust according to different environments + main_config.env.eval_max_episode_steps = int(1e3) # Adjust according to different environments for seed in seeds: returns_mean, returns = eval_muzero( diff --git a/zoo/board_games/tictactoe/envs/tictactoe_env.py b/zoo/board_games/tictactoe/envs/tictactoe_env.py index 303365dfe..674b58310 100644 --- a/zoo/board_games/tictactoe/envs/tictactoe_env.py +++ b/zoo/board_games/tictactoe/envs/tictactoe_env.py @@ -123,12 +123,7 @@ def reset(self, start_player_index=0, init_state=None): self.board = np.zeros((self.board_size, self.board_size), dtype="int32") action_mask = np.zeros(self.total_num_actions, 'int8') - - # TODO(pu): debug - legal_actions = _legal_actions_func_lru(tuple(map(tuple, self.board))) - action_mask[legal_actions] = 1 - - # action_mask[self.legal_actions] = 1 + action_mask[self.legal_actions] = 1 if self.battle_mode == 'play_with_bot_mode' or self.battle_mode == 'eval_mode': # In ``play_with_bot_mode`` and ``eval_mode``, we need to set the "to_play" parameter in the "obs" dict to -1, diff --git a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py index 7a47cb95f..b8499755b 100644 --- a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py 
@@ -38,7 +38,8 @@ discrete_action_encoding_type='one_hot', norm_type='BN', ), - mcts_ctree=True, + # mcts_ctree=True, + mcts_ctree=False, gumbel_algo=False, cuda=True, env_type='not_board_games', diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 73ddb8cce..64fbe475e 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -9,7 +9,7 @@ collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 -num_simulations = 50 +num_simulations = 100 update_per_collect = 200 batch_size = 512 max_env_step = int(5e6) @@ -31,7 +31,7 @@ # ============================================================== atari_muzero_config = dict( - exp_name=f'data_mz_ctree/game_2048_nct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_adam-wd0_seed0', + exp_name=f'data_mz_ctree/game_2048_npct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_seed0', env=dict( stop_value=int(1e6), env_name=env_name, @@ -72,7 +72,7 @@ weight_decay=1e-4, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=0, # default is 0 + ssl_loss_weight=2, # default is 0 n_episode=n_episode, eval_freq=int(2e3), replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 15ae2999a..cdd4f37f8 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -6,8 +6,7 @@ # ============================================================== env_name = 'game_2048' action_space_size = 4 -use_ture_chance_label_in_chance_encoder = True -# use_ture_chance_label_in_chance_encoder = False +use_ture_chance_label_in_chance_encoder = True # option: {True, False} collector_env_num = 8 n_episode = 8 evaluator_env_num = 3 @@ -33,7 +32,7 @@ # ============================================================== game_2048_stochastic_muzero_config = dict( - exp_name=f'data_stochastic_mz_ctree/game_2048_nct-{num_of_possible_chance_tile}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}-{chance_space_size}_seed0', + exp_name=f'data_stochastic_mz_ctree/game_2048_npct-{num_of_possible_chance_tile}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}_sslw2_seed0', env=dict( stop_value=int(1e6), env_name=env_name, @@ -76,10 +75,10 @@ weight_decay=1e-4, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - ssl_loss_weight=0, # default is 0 + ssl_loss_weight=2, # default is 0 n_episode=n_episode, eval_freq=int(2e3), - replay_buffer_size=int(2e7), # the size/capacity of replay_buffer, in the terms of transitions. + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ), diff --git a/zoo/game_2048/entry/rule_based_2048_config.py b/zoo/game_2048/entry/rule_based_2048_config.py index 651adc897..7d48065ac 100644 --- a/zoo/game_2048/entry/rule_based_2048_config.py +++ b/zoo/game_2048/entry/rule_based_2048_config.py @@ -106,9 +106,9 @@ def move(grid: np.array, action: int, game_score: int = 0) -> Tuple[np.array, bo if action == 0: grid = np.rot90(grid) elif action == 1: - grid = np.rot90(grid, k=3) - elif action == 2: grid = np.rot90(grid, k=2) + elif action == 2: + grid = np.rot90(grid, k=3) # simple move for i in range(4): for j in range(3): @@ -132,9 +132,9 @@ def move(grid: np.array, action: int, game_score: int = 0) -> Tuple[np.array, bo if action == 0: grid = np.rot90(grid, k=3) elif action == 1: - grid = np.rot90(grid) - elif action == 2: grid = np.rot90(grid, k=2) + elif action == 2: + grid = np.rot90(grid) move_flag = np.any(old_grid != grid) return grid, move_flag, game_score @@ -157,13 +157,13 @@ def generate(grid: np.array) -> np.array: env_name="game_2048", save_replay=False, replay_format='mp4', - replay_name_suffix='ns100_s1', + replay_name_suffix='test', replay_path=None, + render_real_time=False, act_scale=True, channel_last=True, - obs_type='array', - raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' - reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' + obs_type='array', # options=['raw_observation', 'dict_observation', 'array'] + reward_type='raw', # options=['raw', 'merged_tiles_plus_log_max_tile_num'] reward_normalize=False, reward_norm_scale=100, max_tile=int(2 ** 16), @@ -173,6 +173,9 @@ def generate(grid: np.array) -> np.array: is_collect=False, ignore_legal_actions=True, need_flatten=False, + num_of_possible_chance_tile=2, + possible_tiles=np.array([2, 4]), + tile_probabilities=np.array([0.9, 0.1]), )) if __name__ == "__main__": @@ -182,21 +185,16 @@ def generate(grid: np.array) -> np.array: game_2048_env.render() step = 0 while True: - print('=' * 20) + print('=' * 40) grid = obs.astype(np.int64) - # action = game_2048_env.human_to_action() - action = game_2048_env.random_action() - # action = rule_based_search(grid) - if action == 1: - action = 2 - elif action == 2: - action = 1 + # action = game_2048_env.human_to_action() # which obtain about 10000 score + # action = game_2048_env.random_action() # which obtain about 1000 score + action = rule_based_search(grid) # which obtain about 58536 score try: obs, reward, done, info = game_2048_env.step(action) except Exception as e: print(f'Exception: {e}') print('total_step_number: {}'.format(step)) - game_2048_env.save_render_gif(replay_name_suffix='bot') break step += 1 print(f"step: {step}, action: {action}, reward: {reward}, raw_reward: {info['raw_reward']}") diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 79b346f80..9884f1683 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -1,5 +1,4 @@ import copy -import itertools import logging import sys from typing import List @@ -32,7 +31,7 @@ class Game2048Env(gym.Env): obs_type='raw_observation', # options=['raw_observation', 'dict_observation', 'array'] reward_normalize=False, reward_norm_scale=100, - reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' + reward_type='raw', # options=['raw', 'merged_tiles_plus_log_max_tile_num'] max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 delay_reward_step=0, prob_random_agent=0., @@ -92,7 +91,6 @@ def __init__(self, 
cfg: dict) -> None: self._action_space = spaces.Discrete(4) self._observation_space = spaces.Box(0, 1, (self.w, self.h, self.squares), dtype=int) self._reward_range = (0., self.max_tile) - self.set_illegal_move_reward(0.) # for render self.grid_size = 70 @@ -103,35 +101,39 @@ def __init__(self, cfg: dict) -> None: self.possible_tiles = cfg.possible_tiles self.tile_probabilities = cfg.tile_probabilities if self.num_of_possible_chance_tile > 2: - self.possible_tiles = np.array([2**(i+1) for i in range(self.num_of_possible_chance_tile)]) - self.tile_probabilities = np.array([1/self.num_of_possible_chance_tile for _ in range(self.num_of_possible_chance_tile)]) + self.possible_tiles = np.array([2 ** (i + 1) for i in range(self.num_of_possible_chance_tile)]) + self.tile_probabilities = np.array( + [1 / self.num_of_possible_chance_tile for _ in range(self.num_of_possible_chance_tile)]) assert self.possible_tiles.shape[0] == self.tile_probabilities.shape[0] assert np.sum(self.tile_probabilities) == 1 - def reset(self): + def reset(self, init_board=None, add_random_tile_flag=True): """Reset the game board-matrix and add 2 tiles.""" self.episode_length = 0 - self.board = np.zeros((self.h, self.w), np.int32) + self.add_random_tile_flag = add_random_tile_flag + if init_board is not None: + self.board = copy.deepcopy(init_board) + else: + self.board = np.zeros((self.h, self.w), np.int32) + # Add two tiles at the start of the game + for _ in range(2): + if self.num_of_possible_chance_tile > 2: + self.add_random_tile(self.possible_tiles, self.tile_probabilities) + elif self.num_of_possible_chance_tile == 2: + self.add_random_2_4_tile() + self.episode_return = 0 self._final_eval_reward = 0.0 self.should_done = False - - logging.debug("Adding tiles") - # TODO(pu): why add_tiles twice? - if self.num_of_possible_chance_tile > 2: - self.add_random_tile(self.possible_tiles, self.tile_probabilities) - self.add_random_tile(self.possible_tiles, self.tile_probabilities) - elif self.num_of_possible_chance_tile == 2: - self.add_random_2_4_tile() - self.add_random_2_4_tile() - + # Create a mask for legal actions action_mask = np.zeros(4, 'int8') action_mask[self.legal_actions] = 1 - observation = encode_board(self.board) - observation = observation.astype(np.float32) + # Encode the board, ensure correct datatype and shape + observation = encode_board(self.board).astype(np.float32) assert observation.shape == (4, 4, 16) + # Reshape or transpose the observation as per the requirement if not self.channel_last: # move channel dim to fist axis # (W, H, C) -> (C, W, H) @@ -140,27 +142,48 @@ def reset(self): if self.need_flatten: observation = observation.reshape(-1) + # Based on the observation type, create the appropriate observation object if self.obs_type == 'dict_observation': - observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} + observation = { + 'observation': observation, + 'action_mask': action_mask, + 'to_play': -1, + 'chance': self.chance + } elif self.obs_type == 'array': observation = self.board - else: - observation = observation + + # Render the game if the replay is to be saved if self.save_replay: self.render(mode='rgb_array_render') + return observation def step(self, action): - """Perform one step of the game. This involves moving and adding a new tile.""" + """ + Overview: + Perform one step of the game. This involves making a move, adding a new tile, and updating the game state. + New tile could be added randomly or from the tile probabilities. 
+ The rewards are calculated based on the game configuration ('merged_tiles_plus_log_max_tile_num' or 'raw'). + The observations are also returned based on the game configuration ('dict_observation', 'array', or 'raw'). + Arguments: + - action (:obj:`int`): The action to be performed. + Returns: + - BaseEnvTimestep: Contains the new state observation, reward, and other game information. + """ + + # Increment the total episode length self.episode_length += 1 + # Check if the action is legal, otherwise choose a random legal action if action not in self.legal_actions: logging.warning( - f"You input illegal action: {action}, the legal_actions are {self.legal_actions}. " - f"Now we randomly choice a action from self.legal_actions." + f"Illegal action: {action}. Legal actions: {self.legal_actions}. " + "Choosing a random action from legal actions." ) action = np.random.choice(self.legal_actions) + # Calculate the reward differently based on the reward type if self.reward_type == 'merged_tiles_plus_log_max_tile_num': empty_num1 = len(self.get_empty_location()) raw_reward = float(self.move(action)) @@ -173,38 +196,40 @@ def step(self, action): reward_merged_tiles_plus_log_max_tile_num += np.log2(max_tile_num) * 0.1 self.max_tile_num = max_tile_num + # Update total reward and add new tile self.episode_return += raw_reward assert raw_reward <= 2 ** (self.w * self.h) - if self.num_of_possible_chance_tile > 2: - self.add_random_tile(self.possible_tiles, self.tile_probabilities) - elif self.num_of_possible_chance_tile == 2: - self.add_random_2_4_tile() - done = self.is_end() + if self.add_random_tile_flag: + if self.num_of_possible_chance_tile > 2: + self.add_random_tile(self.possible_tiles, self.tile_probabilities) + elif self.num_of_possible_chance_tile == 2: + self.add_random_2_4_tile() + + # Check if the game has ended + done = self.is_done() + + # Convert rewards to float if self.reward_type == 'merged_tiles_plus_log_max_tile_num': reward_merged_tiles_plus_log_max_tile_num = float(reward_merged_tiles_plus_log_max_tile_num) elif self.reward_type == 'raw': raw_reward = float(raw_reward) + # End the game if the maximum steps have been reached if self.episode_length >= self.max_episode_steps: - # print("episode_length: {}".format(self.episode_length)) done = True + # Prepare the game state observation observation = encode_board(self.board) observation = observation.astype(np.float32) - assert observation.shape == (4, 4, 16) - if not self.channel_last: - # move channel dim to fist axis - # (W, H, C) -> (C, W, H) - # e.g. 
(4, 4, 16) -> (16, 4, 4) observation = np.transpose(observation, [2, 0, 1]) - if self.need_flatten: observation = observation.reshape(-1) action_mask = np.zeros(4, 'int8') action_mask[self.legal_actions] = 1 + # Return the observation based on the observation type if self.obs_type == 'dict_observation': observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} elif self.obs_type == 'array': @@ -212,6 +237,7 @@ def step(self, action): else: observation = observation + # Normalize the reward if necessary if self.reward_normalize: reward_normalize = raw_reward / self.reward_norm_scale reward = reward_normalize @@ -220,174 +246,183 @@ def step(self, action): self._final_eval_reward += raw_reward + # Convert the reward to ndarray if self.reward_type == 'merged_tiles_plus_log_max_tile_num': reward = to_ndarray([reward_merged_tiles_plus_log_max_tile_num]).astype(np.float32) elif self.reward_type == 'raw': reward = to_ndarray([reward]).astype(np.float32) - info = {"raw_reward": raw_reward, "current_max_tile_num": self.highest()} + # Prepare information to return + info = {"raw_reward": raw_reward, "current_max_tile_num": self.highest()} if self.save_replay: self.render(mode='rgb_array_render') + # If the game has ended, save additional information and the replay if necessary if done: info['eval_episode_return'] = self._final_eval_reward if self.save_replay: - self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path, format=self.replay_format) + self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path, + format=self.replay_format) return BaseEnvTimestep(observation, reward, done, info) def move(self, direction, trial=False): """ Overview: - Perform one move of the game. Shift things to one side then, - combine. directions 0, 1, 2, 3 are up, right, down, left. - Returns the reward that [would have] got. + Perform one move in the game. The game board can be shifted in one of four directions: up (0), right (1), down (2), or left (3). + This method manages the shifting process and combines similar adjacent elements. It also returns the reward generated from the move. Arguments: - - direction (:obj:`int`): The direction to move. - - trial (:obj:`bool`): Whether this is a trial move. + - direction (:obj:`int`): The direction of the move. + - trial (:obj:`bool`): If true, this move is only simulated and does not change the actual game state. 
""" + # TODO(pu): different transition dynamics + # Logging the direction of the move if not a trial if not trial: - if direction == 0: - logging.debug("Up") - elif direction == 1: - logging.debug("Right") - elif direction == 2: - logging.debug("Down") - elif direction == 3: - logging.debug("Left") - - changed = False + logging.debug(["Up", "Right", "Down", "Left"][int(direction)]) + move_reward = 0 - dir_div_two = int(direction / 2) - dir_mod_two = int(direction % 2) - # 0 for towards up or left, 1 for towards bottom or right - shift_direction = dir_mod_two ^ dir_div_two + # Calculate merge direction of the shift (0 for up/left, 1 for down/right) based on the input direction + merge_direction = 0 if direction in [0, 3] else 1 # Construct a range for extracting row/column into a list - rx = list(range(self.w)) - ry = list(range(self.h)) + range_x = list(range(self.w)) + range_y = list(range(self.h)) - if dir_mod_two == 0: - # Up or down, split into columns + # If direction is up or down, process the board column by column + if direction in [0, 2]: for y in range(self.h): - old = [self.board[x, y] for x in rx] - (new, ms) = self.shift(old, shift_direction) - move_reward += ms - if old != new: - changed = True - if not trial: - for x in rx: - self.board[x, y] = new[x] - + old_col = [self.board[x, y] for x in range_x] + new_col, reward = self.shift(old_col, merge_direction) + move_reward += reward + if old_col != new_col and not trial: # Update the board if it's not a trial move + for x in range_x: + self.board[x, y] = new_col[x] + # If direction is left or right, process the board row by row else: - # Left or right, split into rows for x in range(self.w): - old = [self.board[x, y] for y in ry] - (new, ms) = self.shift(old, shift_direction) - move_reward += ms - if old != new: - changed = True - if not trial: - for y in ry: - self.board[x, y] = new[y] - - # TODO(pu): different transition dynamics - # if not changed: - # raise IllegalMove + old_row = [self.board[x, y] for y in range_y] + new_row, reward = self.shift(old_row, merge_direction) + move_reward += reward + if old_row != new_row and not trial: # Update the board if it's not a trial move + for y in range_y: + self.board[x, y] = new_row[y] return move_reward - def set_illegal_move_reward(self, reward): + def shift(self, row, merge_direction): """ - Define the reward/penalty for performing an illegal move. Also need to update the reward range for this. + Overview: + This method shifts the elements in a given row or column of the 2048 board in a specified direction. + It performs three main operations: removal of zeroes, combination of similar elements, and filling up the + remaining spaces with zeroes. The direction of shift can be either left (0) or right (1). + Arguments: + - row: A list of integers representing a row or a column in the 2048 board. + - merge_direction: An integer that dictates the direction of merge. It can be either 0 or 1. + - 0: The elements in the 'row' will be merged towards left/up. + - 1: The elements in the 'row' will be merged towards right/down. + Returns: + - combined_row: A list of integers of the same length as 'row' after shifting and merging. + - move_reward: The reward gained from combining similar elements in 'row'. It is the sum of all new + combinations. + Note: + This method assumes that the input 'row' is a list of integers and 'merge_direction' is either 0 or 1. """ - # Guess that the maximum reward is also 2**squares though you'll probably never get that. 
- # (assume that illegal move reward is the lowest value that can be returned - # TODO: check that this is correct - self.illegal_move_reward = reward - self.reward_range = (self.illegal_move_reward, float(2 ** self.squares)) - def render(self, mode='human'): - if mode == 'rgb_array_render': - grey = (128, 128, 128) - grid_size = self.grid_size + # Remove the zero elements from the row and store it in a new list. + non_zero_row = [i for i in row if i != 0] - # Render with Pillow - pil_board = Image.new("RGB", (grid_size * 4, grid_size * 4)) - draw = ImageDraw.Draw(pil_board) - draw.rectangle([0, 0, 4 * grid_size, 4 * grid_size], grey) - fnt_path = fm.findfont(fm.FontProperties(family='DejaVu Sans')) - fnt = ImageFont.truetype(fnt_path, 30) + # Determine the start, stop, and step values based on the direction of shift. + # If the direction is left (0), we start at the first element and move forwards. + # If the direction is right (1), we start at the last element and move backwards. + start, stop, step = (0, len(non_zero_row), 1) if merge_direction == 0 else (len(non_zero_row) - 1, -1, -1) - for y in range(4): - for x in range(4): - o = self.board[y, x] - if o: - self.draw_tile(draw, x, y, o, fnt) + # Call the combine function to merge the adjacent, same elements in the row. + combined_row, move_reward = self.combine(non_zero_row, start, stop, step) - # Instead of returning the image, we display it using pyplot - if self.render_real_time: - plt.imshow(np.asarray(pil_board)) - plt.draw() - # plt.pause(0.001) - # Append the frame to frames for gif - self.frames.append(np.asarray(pil_board)) - elif mode == 'human': - s = 'Current Return: {}, '.format(self.episode_return) - s += 'Current Highest Tile number: {}\n'.format(self.highest()) - npa = np.array(self.board) - grid = npa.reshape((self.size, self.size)) - s += "{}\n".format(grid) - sys.stdout.write(s) - return sys.stdout + if merge_direction == 1: + # If direction is 'right'/'down', reverse the row + combined_row = combined_row[::-1] - def draw_tile(self, draw, x, y, o, fnt): - grid_size = self.grid_size - white = (255, 255, 255) - tile_colour_map = { - 0: (204, 192, 179), - 2: (238, 228, 218), - 4: (237, 224, 200), - 8: (242, 177, 121), - 16: (245, 149, 99), - 32: (246, 124, 95), - 64: (246, 94, 59), - 128: (237, 207, 114), - 256: (237, 204, 97), - 512: (237, 200, 80), - 1024: (237, 197, 63), - 2048: (237, 194, 46), - 4096: (237, 194, 46), - 8192: (237, 194, 46), - 16384: (237, 194, 46), - } - if o: - draw.rectangle([x * grid_size, y * grid_size, (x + 1) * grid_size, (y + 1) * grid_size], - tile_colour_map[o]) - bbox = draw.textbbox((x, y), str(o), font=fnt) - text_x_size, text_y_size = bbox[2] - bbox[0], bbox[3] - bbox[1] - draw.text((x * grid_size + (grid_size - text_x_size) // 2, - y * grid_size + (grid_size - text_y_size) // 2), str(o), font=fnt, fill=white) - # assert text_x_size < grid_size - # assert text_y_size < grid_size + # Fill up the remaining spaces in the row with 0, if any. 
+ if merge_direction == 0: + combined_row += [0] * (len(row) - len(combined_row)) + elif merge_direction == 1: + combined_row = [0] * (len(row) - len(combined_row)) + combined_row - def save_render_output(self, replay_name_suffix: str = '', replay_path=None, format='gif'): - # At the end of the episode, save the frames - if replay_path is None: - filename = f'game_2048_{replay_name_suffix}.{format}' - else: - filename = f'{replay_path}.{format}' + return combined_row, move_reward - if format == 'gif': - imageio.mimsave(filename, self.frames, 'GIF') - elif format == 'mp4': - imageio.mimsave(filename, self.frames, fps=30, codec='mpeg4') + def combine(self, row, start, stop, step): + """ + Overview: + Combine similar adjacent elements in the row, starting from the specified start index, + ending at the stop index, and moving in the direction indicated by the step. The function + also calculates the reward as the sum of all combined elements. + """ - else: - raise ValueError("Unsupported format: {}".format(format)) + # Initialize the reward for this move as 0. + move_reward = 0 - logging.info("Saved output to {}".format(filename)) - self.frames = [] + # Initialize the list to store the row after combining same elements. + combined_row = [] + + # Initialize a flag to indicate whether the next element should be skipped. + skip_next = False + + # Iterate over the elements in the row based on the start, stop, and step values. + for i in range(start, stop, step): + # If the next element should be skipped, reset the flag and continue to the next iteration. + if skip_next: + skip_next = False + continue + + # If the current element and the next element are the same, combine them. + if i + step != stop and row[i] == row[i + step]: + combined_row.append(row[i] * 2) + move_reward += row[i] * 2 + # Set the flag to skip the next element in the next iteration. + skip_next = True + else: + # If the current element and the next element are not the same, just append the current element to the result. + combined_row.append(row[i]) + + return combined_row, move_reward + + @property + def legal_actions(self): + """ + Overview: + Return the legal actions for the current state. A move is considered legal if it changes the state of the board. + """ + + if self.ignore_legal_actions: + return [0, 1, 2, 3] + + legal_actions = [] + + # For each direction, simulate a move. 
If the move changes the board, add the direction to the list of legal actions + for direction in range(4): + # Calculate merge direction of the shift (0 for up/left, 1 for down/right) based on the input direction + merge_direction = 0 if direction in [0, 3] else 1 + + range_x = list(range(self.w)) + range_y = list(range(self.h)) + + if direction % 2 == 0: + for y in range(self.h): + old_col = [self.board[x, y] for x in range_x] + new_col, _ = self.shift(old_col, merge_direction) + if old_col != new_col: + legal_actions.append(direction) + break # As soon as we know the move is legal, we can stop checking + else: + for x in range(self.w): + old_row = [self.board[x, y] for y in range_y] + new_row, _ = self.shift(old_row, merge_direction) + if old_row != new_row: + legal_actions.append(direction) + break # As soon as we know the move is legal, we can stop checking + + return legal_actions # Implementation of game logic for 2048 def add_random_2_4_tile(self): @@ -414,7 +449,8 @@ def add_random_2_4_tile(self): self.board[empty[0], empty[1]] = tile_val - def add_random_tile(self, possible_tiles: np.array = np.array([2, 4]), tile_probabilities: np.array = np.array([0.9, 0.1])): + def add_random_tile(self, possible_tiles: np.array = np.array([2, 4]), + tile_probabilities: np.array = np.array([0.9, 0.1])): """Add a tile with a value from possible_tiles array according to given probabilities.""" if len(possible_tiles) != len(tile_probabilities): raise ValueError("Length of possible_tiles and tile_probabilities must be the same") @@ -436,6 +472,7 @@ def add_random_tile(self, possible_tiles: np.array = np.array([2, 4]), tile_prob self.chance = tile_idx * 16 + 4 * empty[0] + empty[1] self.board[empty[0], empty[1]] = tile_val + def get_empty_location(self): """Return a 2d numpy array with the location of empty squares.""" return np.argwhere(self.board == 0) @@ -444,102 +481,7 @@ def highest(self): """Report the highest tile on the board.""" return np.max(self.board) - @property - def legal_actions(self): - """ - Overview: - Return the legal actions for the current state. - Arguments: - - None - Returns: - - legal_actions (:obj:`list`): The legal actions. - """ - if self.ignore_legal_actions: - return [0, 1, 2, 3] - legal_actions = [] - for direction in range(4): - changed = False - move_reward = 0 - dir_div_two = int(direction / 2) - dir_mod_two = int(direction % 2) - # 0 for towards up or left, 1 for towards bottom or right - shift_direction = dir_mod_two ^ dir_div_two - - # Construct a range for extracting row/column into a list - rx = list(range(self.w)) - ry = list(range(self.h)) - - if dir_mod_two == 0: - # Up or down, split into columns - for y in range(self.h): - old = [self.board[x, y] for x in rx] - (new, move_reward_tmp) = self.shift(old, shift_direction) - move_reward += move_reward_tmp - if old != new: - changed = True - else: - # Left or right, split into rows - for x in range(self.w): - old = [self.board[x, y] for y in ry] - (new, move_reward_tmp) = self.shift(old, shift_direction) - move_reward += move_reward_tmp - if old != new: - changed = True - - if changed: - legal_actions.append(direction) - - return legal_actions - - def combine(self, shifted_row): - """ - Overview: - Combine same tiles when moving to one side. This function always - shifts towards the left. Also count the reward of combined tiles. 
- """ - move_reward = 0 - combined_row = [0] * self.size - skip = False - output_index = 0 - for p in pairwise(shifted_row): - if skip: - skip = False - continue - combined_row[output_index] = p[0] - if p[0] == p[1]: - combined_row[output_index] += p[1] - move_reward += p[0] + p[1] - # Skip the next thing in the list. - skip = True - output_index += 1 - if shifted_row and not skip: - combined_row[output_index] = shifted_row[-1] - - return combined_row, move_reward - - def shift(self, row, direction): - """Shift one row left (direction == 0) or right (direction == 1), combining if required.""" - length = len(row) - assert length == self.size - # assert direction == 0 or direction == 1 - - # Shift all non-zero digits up - shifted_row = [i for i in row if i != 0] - - # Reverse list to handle shifting to the right - if direction: - shifted_row.reverse() - - (combined_row, move_reward) = self.combine(shifted_row) - - # Reverse list to handle shifting to the right - if direction: - combined_row.reverse() - - assert len(combined_row) == self.size - return combined_row, move_reward - - def is_end(self): + def is_done(self): """Has the game ended. Game ends if there is a tile equal to the limit or there are no legal moves. If there are empty spaces then there must be legal moves.""" @@ -575,6 +517,113 @@ def random_action(self) -> np.ndarray: random_action = to_ndarray([random_action], dtype=np.int64) return random_action + def human_to_action(self): + """ + Overview: + For multiplayer games, ask the user for a legal action + and return the corresponding action number. + Returns: + An integer from the action space. + """ + # print(self.board) + while True: + try: + action = int( + input( + f"Enter the action (0, 1, 2, or 3, ) to play: " + ) + ) + if action in self.legal_actions: + break + else: + print("Wrong input, try again") + except KeyboardInterrupt: + print("exit") + sys.exit(0) + return action + + def render(self, mode='human'): + if mode == 'rgb_array_render': + grey = (128, 128, 128) + grid_size = self.grid_size + + # Render with Pillow + pil_board = Image.new("RGB", (grid_size * 4, grid_size * 4)) + draw = ImageDraw.Draw(pil_board) + draw.rectangle([0, 0, 4 * grid_size, 4 * grid_size], grey) + fnt_path = fm.findfont(fm.FontProperties(family='DejaVu Sans')) + fnt = ImageFont.truetype(fnt_path, 30) + + for y in range(4): + for x in range(4): + o = self.board[y, x] + if o: + self.draw_tile(draw, x, y, o, fnt) + + # Instead of returning the image, we display it using pyplot + if self.render_real_time: + plt.imshow(np.asarray(pil_board)) + plt.draw() + # plt.pause(0.001) + # Append the frame to frames for gif + self.frames.append(np.asarray(pil_board)) + elif mode == 'human': + s = 'Current Return: {}, '.format(self.episode_return) + s += 'Current Highest Tile number: {}\n'.format(self.highest()) + npa = np.array(self.board) + grid = npa.reshape((self.size, self.size)) + s += "{}\n".format(grid) + sys.stdout.write(s) + return sys.stdout + + def draw_tile(self, draw, x, y, o, fnt): + grid_size = self.grid_size + white = (255, 255, 255) + tile_colour_map = { + 0: (204, 192, 179), + 2: (238, 228, 218), + 4: (237, 224, 200), + 8: (242, 177, 121), + 16: (245, 149, 99), + 32: (246, 124, 95), + 64: (246, 94, 59), + 128: (237, 207, 114), + 256: (237, 204, 97), + 512: (237, 200, 80), + 1024: (237, 197, 63), + 2048: (237, 194, 46), + 4096: (237, 194, 46), + 8192: (237, 194, 46), + 16384: (237, 194, 46), + } + if o: + draw.rectangle([x * grid_size, y * grid_size, (x + 1) * grid_size, (y + 1) * 
grid_size], + tile_colour_map[o]) + bbox = draw.textbbox((x, y), str(o), font=fnt) + text_x_size, text_y_size = bbox[2] - bbox[0], bbox[3] - bbox[1] + draw.text((x * grid_size + (grid_size - text_x_size) // 2, + y * grid_size + (grid_size - text_y_size) // 2), str(o), font=fnt, fill=white) + # assert text_x_size < grid_size + # assert text_y_size < grid_size + + def save_render_output(self, replay_name_suffix: str = '', replay_path=None, format='gif'): + # At the end of the episode, save the frames + if replay_path is None: + filename = f'game_2048_{replay_name_suffix}.{format}' + else: + filename = f'{replay_path}.{format}' + + if format == 'gif': + imageio.mimsave(filename, self.frames, 'GIF') + elif format == 'mp4': + imageio.mimsave(filename, self.frames, fps=30, codec='mpeg4') + + else: + raise ValueError("Unsupported format: {}".format(format)) + + logging.info("Saved output to {}".format(filename)) + self.frames = [] + @property def observation_space(self) -> gym.spaces.Space: return self._observation_space @@ -591,7 +640,7 @@ def reward_space(self) -> gym.spaces.Space: def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') cfg = copy.deepcopy(cfg) - # when collect data, sometimes we need to normalize the reward + # when in collect phase, sometimes we need to normalize the reward # reward_normalize is determined by the config. cfg.is_collect = True return [cfg for _ in range(collector_env_num)] @@ -600,13 +649,13 @@ def create_collector_env_cfg(cfg: dict) -> List[dict]: def create_evaluator_env_cfg(cfg: dict) -> List[dict]: evaluator_env_num = cfg.pop('evaluator_env_num') cfg = copy.deepcopy(cfg) - # when evaluate, we don't need to normalize the reward. + # when in evaluate phase, we don't need to normalize the reward. cfg.reward_normalize = False cfg.is_collect = False return [cfg for _ in range(evaluator_env_num)] def __repr__(self) -> str: - return "LightZero 2048 Env." + return "LightZero game 2048 Env." 
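# NOTE (illustrative sanity check, not part of the patch): a simplified re-implementation
# of the merge rule encoded by the refactored ``shift()``/``combine()`` above, useful for
# convincing oneself of the expected rewards; the unit tests added below exercise the
# same behaviour at the board level.
def merge_left(row):
    tiles = [t for t in row if t != 0]
    out, reward, i = [], 0, 0
    while i < len(tiles):
        if i + 1 < len(tiles) and tiles[i] == tiles[i + 1]:
            out.append(tiles[i] * 2)       # each pair merges at most once per move
            reward += tiles[i] * 2         # reward is the value of the newly merged tile
            i += 2
        else:
            out.append(tiles[i])
            i += 1
    return out + [0] * (len(row) - len(out)), reward

assert merge_left([2, 2, 4, 0]) == ([4, 4, 0, 0], 4)
assert merge_left([2, 2, 2, 2]) == ([4, 4, 0, 0], 8)
assert merge_left([4, 0, 4, 8]) == ([8, 8, 0, 0], 8)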
def encode_board(flat_board, num_of_template_tiles=16): @@ -636,18 +685,3 @@ def encode_board(flat_board, num_of_template_tiles=16): one_hot_board = (layered_board == tile_values).astype(int) return one_hot_board - - -def pairwise(iterable): - """s -> (s0,s1), (s1,s2), (s2, s3), ...""" - a, b = itertools.tee(iterable) - next(b, None) - return zip(a, b) - - -class IllegalMove(Exception): - pass - - -class IllegalActionError(Exception): - pass diff --git a/zoo/game_2048/envs/test_game_2048_env.py b/zoo/game_2048/envs/test_game_2048_env.py index 98d6c5bb8..441d95280 100644 --- a/zoo/game_2048/envs/test_game_2048_env.py +++ b/zoo/game_2048/envs/test_game_2048_env.py @@ -11,15 +11,17 @@ def env(): # Configuration for the Game2048 environment cfg = EasyDict(dict( env_name="game_2048", - save_replay_gif=False, - replay_path_gif=None, + save_replay=False, + replay_format='gif', + replay_name_suffix='eval', replay_path=None, + render_real_time=False, act_scale=True, channel_last=True, obs_type='raw_observation', # options=['raw_observation', 'dict_observation', 'array'] + reward_type='raw', # options=['raw', 'merged_tiles_plus_log_max_tile_num'] reward_normalize=False, reward_norm_scale=100, - reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 delay_reward_step=0, prob_random_agent=0., @@ -27,6 +29,9 @@ def env(): is_collect=True, ignore_legal_actions=True, need_flatten=False, + num_of_possible_chance_tile=2, + possible_tiles=np.array([2, 4]), + tile_probabilities=np.array([0.9, 0.1]), )) return Game2048Env(cfg) @@ -46,7 +51,7 @@ def test_reset(env): # Test the step method of the Game2048 environment. # Ensure that the shape of the observation, the type of the reward, # the type of the done flag and the type of the info are as expected. -def test_step(env): +def test_step_shape(env): env.reset() obs, reward, done, info = env.step(1) assert obs.shape == (4, 4, 16) @@ -61,10 +66,96 @@ def test_render(env): env.reset() env.render(mode='human') env.render(mode='rgb_array_render') - env.save_render_gif() + # Test the seed method of the Game2048 environment. # Ensure that the random seed is set correctly. 
def test_seed(env): env.seed(0) assert env.np_random.randn() != np.random.randn() + + +def test_step_action_case1(env): + init_board = np.array([[8, 4, 0, 0], + [2, 0, 0, 0], + [2, 0, 0, 0], + [2, 4, 2, 0]]) + + # Test action 0 (Assuming it represents 'up' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(0) + expected_board_up = np.array([[8, 8, 2, 0], + [4, 0, 0, 0], + [2, 0, 0, 0], + [0, 0, 0, 0]]) + assert np.array_equal(env.board, expected_board_up) + + # Test action 1 (Assuming it represents 'right' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(1) + expected_board_right = np.array([[0, 0, 8, 4], + [0, 0, 0, 2], + [0, 0, 0, 2], + [0, 2, 4, 2]]) + assert np.array_equal(env.board, expected_board_right) + + # Test action 2 (Assuming it represents 'down' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(2) + expected_board_down = np.array([[0, 0, 0, 0], + [8, 0, 0, 0], + [2, 0, 0, 0], + [4, 8, 2, 0]]) + assert np.array_equal(env.board, expected_board_down) + + # Test action 3 (Assuming it represents 'left' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(3) + expected_board_left = np.array([[8, 4, 0, 0], + [2, 0, 0, 0], + [2, 0, 0, 0], + [2, 4, 2, 0]]) + assert np.array_equal(env.board, expected_board_left) + + +def test_step_action_case2(env): + init_board = np.array([[8, 4, 2, 0], + [2, 0, 2, 0], + [2, 2, 4, 0], + [2, 4, 2, 0]]) + + # Test action 0 (Assuming it represents 'up' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(0) + expected_board_up = np.array([[8, 4, 4, 0], + [4, 2, 4, 0], + [2, 4, 2, 0], + [0, 0, 0, 0]]) + assert np.array_equal(env.board, expected_board_up) + + # Test action 1 (Assuming it represents 'right' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(1) + expected_board_right = np.array([[0, 8, 4, 2], + [0, 0, 0, 4], + [0, 0, 4, 4], + [0, 2, 4, 2]]) + assert np.array_equal(env.board, expected_board_right) + + # Test action 2 (Assuming it represents 'down' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(2) + expected_board_down = np.array([[0, 0, 0, 0], + [8, 4, 4, 0], + [2, 2, 4, 0], + [4, 4, 2, 0]]) + assert np.array_equal(env.board, expected_board_down) + + # Test action 3 (Assuming it represents 'left' move) + env.reset(init_board=init_board, add_random_tile_flag=False) + obs, reward, done, info = env.step(3) + expected_board_left = np.array([[8, 4, 2, 0], + [4, 0, 0, 0], + [4, 4, 0, 0], + [2, 4, 2, 0]]) + assert np.array_equal(env.board, expected_board_left) \ No newline at end of file From 02046f4fb4586156056cc7c27c0f3d1a5ba9ddaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 5 Sep 2023 17:59:12 +0800 Subject: [PATCH 21/28] feature(pu): add stochastic_muzero_model_mlp --- .../ctree_stochastic_muzero/lib/cnode.cpp | 7 - lzero/mcts/ptree/ptree_stochastic_mz.py | 29 +- lzero/model/stochastic_muzero_model.py | 22 +- lzero/model/stochastic_muzero_model_mlp.py | 823 ++++++++++++++++++ lzero/policy/stochastic_muzero.py | 18 +- .../cartpole_stochastic_muzero_config.py | 25 +- 6 files changed, 868 insertions(+), 56 deletions(-) create mode 100644 lzero/model/stochastic_muzero_model_mlp.py diff --git 
a/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp index cf3c8d1e2..004b11099 100644 --- a/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp @@ -580,7 +580,6 @@ namespace tree for (auto leaf_order = 0; leaf_order < leaf_idx_list.size(); ++leaf_order) { int i = leaf_idx_list[leaf_order]; - // Your code here } for (int leaf_order = 0; leaf_order < leaf_idx_list.size(); ++leaf_order) { @@ -589,12 +588,6 @@ namespace tree cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[leaf_order], discount_factor); } - - // for (int i = 0; i < results.num; ++i) - // { - // results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i], is_chance_list[i]); - // cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor); - // } } int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players) diff --git a/lzero/mcts/ptree/ptree_stochastic_mz.py b/lzero/mcts/ptree/ptree_stochastic_mz.py index b9cf679ee..4384f2ab6 100644 --- a/lzero/mcts/ptree/ptree_stochastic_mz.py +++ b/lzero/mcts/ptree/ptree_stochastic_mz.py @@ -1,5 +1,5 @@ """ -The Node, Roots class and related core functions for MuZero. +The Node, Roots class and related core functions for Stochastic MuZero. """ import math import random @@ -14,7 +14,7 @@ class Node: """ Overview: - the node base class for MuZero. + the node base class for Stochastic MuZero. Arguments: """ @@ -40,7 +40,7 @@ def __init__(self, prior: float, legal_actions: List = None, action_space_size: def expand( self, to_play: int, latent_state_index_in_search_path: int, latent_state_index_in_batch: int, reward: float, - policy_logits: List[float], child_is_chance: bool = False + policy_logits: List[float], child_is_chance: bool = True ) -> None: """ Overview: @@ -55,14 +55,11 @@ def expand( self.to_play = to_play self.reward = reward - # assert (self.is_chance != child_is_chance), f"is_chance and child_is_chance should be different, current is {self.is_chance}-{child_is_chance}, " - if self.is_chance is True: child_is_chance = False self.reward = 0.0 if self.legal_actions is None: - # self.legal_actions = np.arange(len(policy_logits)) self.legal_actions = np.arange(self.chance_space_size) self.latent_state_index_in_search_path = latent_state_index_in_search_path self.latent_state_index_in_batch = latent_state_index_in_batch @@ -72,7 +69,6 @@ def expand( self.children[action] = Node(prior, is_chance=child_is_chance) else: child_is_chance = True - #self.legal_actions = np.arange(self.chance_space_size) self.legal_actions = np.arange(len(policy_logits)) self.latent_state_index_in_search_path = latent_state_index_in_search_path self.latent_state_index_in_batch = latent_state_index_in_batch @@ -217,11 +213,11 @@ def prepare( """ for i in range(self.root_num): # to_play: int, latent_state_index_in_search_path: int, latent_state_index_in_batch: int, - # TODO(pu): why latent_state_index_in_search_path=0, latent_state_index_in_batch=i? if to_play is None: - self.roots[i].expand(-1, 0, i, rewards[i], policies[i], child_is_chance=True) + # TODO(pu): why latent_state_index_in_search_path=0, latent_state_index_in_batch=i? 
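# NOTE (illustrative sketch, not part of the patch): ``expand()`` above alternates the two
# node types of Stochastic MuZero -- a decision node branches over agent actions and its
# children are chance (afterstate) nodes, while a chance node branches over chance
# outcomes and its children are decision nodes again. Conceptually:
def expanded_children(node_is_chance, action_space_size, chance_space_size):
    if node_is_chance:
        # chance node: one decision child per chance outcome
        return list(range(chance_space_size)), False
    # decision node: one chance (afterstate) child per action
    return list(range(action_space_size)), True
# e.g. for 2048, expanded_children(False, 4, 32) -> ([0, 1, 2, 3], True), assuming the
# env's chance encoding ``tile_idx * 16 + 4 * row + col`` (2 possible tile values x 16
# board positions = 32 outcomes), as added in ``add_random_tile`` earlier in this series.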
+ self.roots[i].expand(-1, 0, i, rewards[i], policies[i]) else: - self.roots[i].expand(to_play[i], 0, i, rewards[i], policies[i], child_is_chance=True) + self.roots[i].expand(to_play[i], 0, i, rewards[i], policies[i]) self.roots[i].add_exploration_noise(root_noise_weight, noises[i]) self.roots[i].visit_count += 1 @@ -237,9 +233,9 @@ def prepare_no_noise(self, rewards: List[float], policies: List[List[float]], to """ for i in range(self.root_num): if to_play is None: - self.roots[i].expand(-1, 0, i, rewards[i], policies[i], child_is_chance=True) + self.roots[i].expand(-1, 0, i, rewards[i], policies[i]) else: - self.roots[i].expand(to_play[i], 0, i, rewards[i], policies[i], child_is_chance=True) + self.roots[i].expand(to_play[i], 0, i, rewards[i], policies[i]) self.roots[i].visit_count += 1 @@ -516,7 +512,6 @@ def batch_traverse( results.nodes[i] = node # print(f'env {i} one simulation done!') - # return results.nodes, results.latent_state_index_in_search_path, results.latent_state_index_in_batch, results.last_actions, virtual_to_play return results, virtual_to_play @@ -566,10 +561,7 @@ def backpropagate( # TODO(pu): to_play related # true_reward is in the perspective of current player of node - # bootstrap_value = (true_reward if node.to_play == to_play else - true_reward) + discount_factor * bootstrap_value - bootstrap_value = ( - -true_reward if node.to_play == to_play else true_reward - ) + discount_factor * bootstrap_value + bootstrap_value = (-true_reward if node.to_play == to_play else true_reward) + discount_factor * bootstrap_value def batch_backpropagate( @@ -589,7 +581,8 @@ def batch_backpropagate( Backpropagation along the search path to update the attributes. Arguments: - latent_state_index_in_search_path (:obj:`Class Int`): the index of latent state vector. - - discount_factor (:obj:`Class Float`): discount_factor factor used i calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - discount_factor (:obj:`Class Float`): discount_factor factor used i calculating bootstrapped value, + if env is board_games, we set discount_factor=1. - value_prefixs (:obj:`Class List`): the value prefixs of nodes along the search path. - values (:obj:`Class List`): the values to propagate along the search path. - policies (:obj:`Class List`): the policy logits of nodes along the search path. diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 17d24d06c..87700ea17 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -1,7 +1,3 @@ -""" -Overview: - BTW, users can refer to the unittest of these model templates to learn how to use them. 
-""" from typing import Optional, Tuple import math @@ -126,7 +122,7 @@ def __init__( downsample, ) - self.encoder = ChanceEncoder( + self.chance_encoder = ChanceEncoder( observation_shape, chance_space_size ) self.dynamics_network = DynamicsNetwork( @@ -293,7 +289,7 @@ def _representation(self, observation: torch.Tensor) -> torch.Tensor: return latent_state def chance_encode(self, observation: torch.Tensor): - output = self.encoder(observation) + output = self.chance_encoder(observation) return output def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -811,11 +807,11 @@ def __init__(self, observation_space_dimensions, action_dimension): # Specify the action space for the model self.action_space = action_dimension # Define the encoder, which transforms observations into a latent space - self.encoder = ChanceEncoderBackbone(observation_space_dimensions, action_dimension) + self.chance_encoder = ChanceEncoderBackbone(observation_space_dimensions, action_dimension) # Using the Straight Through Estimator method for backpropagation self.onehot_argmax = StraightThroughEstimator() - def forward(self, o_i): + def forward(self, observations): """ Forward method for the ChanceEncoder. This method takes an observation and applies the encoder to transform it to a latent space. Then applies the @@ -826,17 +822,17 @@ def forward(self, o_i): Chance Outcomes section. Args: - o_i (Tensor): Observation tensor. + observations (Tensor): Observation tensor. Returns: chance_t (Tensor): Transformed tensor after applying one-hot argmax. - chance_encoding_t (Tensor): Encoding of the input observation tensor. + chance_encoding (Tensor): Encoding of the input observation tensor. """ # Apply the encoder to the observation - chance_encoding_t = self.encoder(o_i) + chance_encoding = self.chance_encoder(observations) # Apply one-hot argmax to the encoding - chance_onehot_t = self.onehot_argmax(chance_encoding_t) - return chance_encoding_t, chance_onehot_t + chance_onehot = self.onehot_argmax(chance_encoding) + return chance_encoding, chance_onehot class StraightThroughEstimator(nn.Module): diff --git a/lzero/model/stochastic_muzero_model_mlp.py b/lzero/model/stochastic_muzero_model_mlp.py new file mode 100644 index 000000000..99e70f9b5 --- /dev/null +++ b/lzero/model/stochastic_muzero_model_mlp.py @@ -0,0 +1,823 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP +from .stochastic_muzero_model import StraightThroughEstimator +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +@MODEL_REGISTRY.register('StochasticMuZeroModelMLP') +class StochasticMuZeroModelMLP(nn.Module): + + def __init__( + self, + observation_shape: int = 2, + action_space_size: int = 6, + chance_space_size: int = 2, + latent_state_dim: int = 256, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + discrete_action_encoding_type: str = 
'one_hot', + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + *args, + **kwargs + ): + """ + Overview: + The definition of the network model of Stochastic, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Stochastic model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(StochasticMuZeroModelMLP, self).__init__() + self.categorical_distribution = categorical_distribution + if not self.categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.chance_space_size = chance_space_size + + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. 
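# Illustrative sketch (not part of the patch above): the constructor code below derives the
# width of the action encoding that is later concatenated to the latent state from
# `discrete_action_encoding_type`. A minimal restatement of that rule:
def action_encoding_dim(action_space_size: int, encoding_type: str) -> int:
    if encoding_type == 'one_hot':
        return action_space_size   # e.g. 4 discrete actions -> a 4-dim one-hot vector
    if encoding_type == 'not_one_hot':
        return 1                   # a single scalar, roughly action / action_space_size
    raise ValueError(f"unknown discrete_action_encoding_type: {encoding_type}")

assert action_encoding_dim(4, 'one_hot') == 4
assert action_encoding_dim(4, 'not_one_hot') == 1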
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.latent_state_dim = latent_state_dim + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type + ) + + self.chance_encoder = ChanceEncoder(observation_shape*2, chance_space_size) # input is two concatenated frames + self.dynamics_network = DynamicsNetwork( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + + self.afterstate_dynamics_network = AfterstateDynamicsNetwork( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + self.afterstate_prediction_network = AfterstatePredictionNetworkMLP( + chance_space_size=chance_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Initial inference of Stochastic model, which is the first step of the Stochastic model. 
+ To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward`` for the next step of the Stochastic model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + batch_size = obs.size(0) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor, + afterstate: bool = False) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of Stochastic MuZero model, which is the rollout step of the Stochastic MuZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward``, by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits`` of the current + ``latent_state``. + Arguments: + - state (:obj:`torch.Tensor`): The encoding latent state of input state or the afterstate. + - option (:obj:`torch.Tensor`): The action to rollout or the chance to predict next latent state. + - afterstate (:obj:`bool`): Whether to use afterstate prediction network to predict next latent state. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. 
+ - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + + if afterstate: + # state is afterstate, option is chance + next_latent_state, reward = self._dynamics(state, option) + policy_logits, value = self._prediction(next_latent_state) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + else: + # state is latent_state, option is action + next_afterstate, reward = self._afterstate_dynamics(state, option) + policy_logits, value = self._afterstate_prediction(next_afterstate) + return MZNetworkOutput(value, reward, policy_logits, next_afterstate) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def chance_encode(self, observation: torch.Tensor): + output = self.chance_encoder(observation) + return output + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + policy_logits, value = self.prediction_network(latent_state) + return policy_logits, value + + def _afterstate_prediction(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the prediction network to predict ``policy_logits`` and ``value``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. 
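# Illustrative sketch (not part of the patch above): one unrolled Stochastic MuZero step is
# split into two `recurrent_inference` calls -- the action first maps a latent state to an
# afterstate (whose "policy" head predicts the chance-outcome distribution), and the chance
# outcome then maps the afterstate to the next latent state with reward, policy and value.
# A rough usage sketch under assumed shapes; it assumes the usual MZNetworkOutput field
# order (value, reward, policy_logits, latent_state) and the constructor defaults of this file.
import torch
from lzero.model.stochastic_muzero_model_mlp import StochasticMuZeroModelMLP

model = StochasticMuZeroModelMLP(observation_shape=8, action_space_size=4, chance_space_size=2)
obs = torch.randn(2, 8)
init = model.initial_inference(obs)

action = torch.tensor([0, 3])
phase_a = model.recurrent_inference(init.latent_state, action, afterstate=False)
afterstate, chance_logits = phase_a.latent_state, phase_a.policy_logits  # sigma over chance outcomes

chance = torch.argmax(chance_logits, dim=1)
phase_b = model.recurrent_inference(afterstate, chance, afterstate=True)
next_latent_state, reward = phase_b.latent_state, phase_b.reward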
+ """ + return self.afterstate_prediction_network(afterstate) + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``reward`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if not self.state_norm: + return next_latent_state, reward + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, reward + + def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``reward`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. 
+ - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if not self.state_norm: + return next_latent_state, reward + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, reward + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. 
+ + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in Stochastic algorithm, which is used to predict next latent state + reward by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
+ """ + super().__init__() + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.res_connection_in_dynamics = res_connection_in_dynamics + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_reward_head = MLP( + in_channels=self.latent_state_dim, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + """ + if self.res_connection_in_dynamics: + # take the state encoding (e.g. 
latent_state), + # state_action_encoding[:, -self.action_encoding_dim:] is action encoding + latent_state = state_action_encoding[:, :-self.action_encoding_dim] + x = self.fc_dynamics_1(state_action_encoding) + # the residual link: add the latent_state to the state_action encoding + next_latent_state = x + latent_state + next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) + else: + next_latent_state = self.fc_dynamics(state_action_encoding) + next_latent_state_encoding = next_latent_state + + reward = self.fc_reward_head(next_latent_state_encoding) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> float: + return get_reward_mean(self) + + +class AfterstateDynamicsNetwork(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in Stochastic algorithm, which is used to predict next latent state + reward by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
+ """ + super().__init__() + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.res_connection_in_dynamics = res_connection_in_dynamics + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_reward_head = MLP( + in_channels=self.latent_state_dim, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + """ + if self.res_connection_in_dynamics: + # take the state encoding (e.g. 
latent_state), + # state_action_encoding[:, -self.action_encoding_dim:] is action encoding + latent_state = state_action_encoding[:, :-self.action_encoding_dim] + x = self.fc_dynamics_1(state_action_encoding) + # the residual link: add the latent_state to the state_action encoding + next_latent_state = x + latent_state + next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) + else: + next_latent_state = self.fc_dynamics(state_action_encoding) + next_latent_state_encoding = next_latent_state + + reward = self.fc_reward_head(next_latent_state_encoding) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> float: + return get_reward_mean(self) + + +class AfterstatePredictionNetworkMLP(nn.Module): + + def __init__( + self, + chance_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + Overview: + The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), + which is used to predict value and policy by the given latent state. + Arguments: + - chance_space_size: (:obj:`int`): Chance space size, usually an integer number. For discrete action \ + space, it is the number of discrete chance outcomes. + - num_channels (:obj:`int`): The channels of latent states. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_policy_layers[0], + out_channels=chance_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. 
+ last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor): + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + """ + x_prediction_common = self.fc_prediction_common(latent_state) + + value = self.fc_value_head(x_prediction_common) + policy = self.fc_policy_head(x_prediction_common) + return policy, value + +class ChanceEncoderBackbone(nn.Module): + def __init__(self, input_dimensions, chance_encoding_dim=4): + super(ChanceEncoderBackbone, self).__init__() + self.fc1 = nn.Linear(input_dimensions, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, chance_encoding_dim) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class ChanceEncoder(nn.Module): + def __init__(self, input_dimensions, action_dimension): + super().__init__() + # Specify the action space for the model + self.action_space = action_dimension + # Define the encoder, which transforms observations into a latent space + self.encoder = ChanceEncoderBackbone(input_dimensions, action_dimension) + # Using the Straight Through Estimator method for backpropagation + self.onehot_argmax = StraightThroughEstimator() + + def forward(self, observations): + """ + Forward method for the ChanceEncoder. This method takes an observation + and applies the encoder to transform it to a latent space. Then applies the + StraightThroughEstimator to this encoding. + + References: + Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5, + Chance Outcomes section. + + Args: + observations (Tensor): Observation tensor. + + Returns: + chance_t (Tensor): Transformed tensor after applying one-hot argmax. + chance_encoding (Tensor): Encoding of the input observation tensor. 
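# Illustrative sketch (not part of the patch above): `onehot_argmax` is a
# StraightThroughEstimator, for which the standard trick is sketched below -- the forward
# pass returns the hard one-hot argmax while the backward pass passes gradients through the
# soft encoding unchanged. This is a generic restatement of the technique, not a copy of the
# StraightThroughEstimator class defined in stochastic_muzero_model.py.
import torch

def straight_through_onehot_argmax(x: torch.Tensor) -> torch.Tensor:
    hard = torch.zeros_like(x).scatter_(1, x.argmax(dim=1, keepdim=True), 1.0)
    # hard one-hot in the forward pass, identity gradient in the backward pass
    return hard + x - x.detach()

logits = torch.randn(4, 8, requires_grad=True)
onehot = straight_through_onehot_argmax(logits)
onehot.sum().backward()  # gradients reach `logits` despite the non-differentiable argmax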
+ """ + # Apply the encoder to the observation + chance_encoding = self.encoder(observations) + # Apply one-hot argmax to the encoding + chance_onehot = self.onehot_argmax(chance_encoding) + return chance_encoding, chance_onehot \ No newline at end of file diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index b375bba88..63c603269 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -295,7 +295,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in for i in range(self._cfg.num_unroll_steps): beg_index = self._cfg.model.image_channel * i end_index = self._cfg.model.image_channel * (i + self._cfg.model.frame_stack_num) - obs_list_for_chance_encoder.append(obs_target_batch[:, beg_index:end_index, :, :]) + if self._cfg.model.model_type == 'conv': + obs_list_for_chance_encoder.append(obs_target_batch[:, beg_index:end_index, :, :]) + elif self._cfg.model.model_type == 'mlp': + obs_list_for_chance_encoder.append(obs_target_batch[:, beg_index*self._cfg.model.observation_shape:end_index*self._cfg.model.observation_shape]) # do augmentations if self._cfg.use_augmentation: @@ -380,7 +383,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_state, action_batch[:, step_i], afterstate=False ) afterstate, afterstate_reward, afterstate_value, afterstate_policy_logits = mz_network_output_unpack(network_output) - + + # ============================================================== + # encode the consecutive frames to predict chance + # ============================================================== # concat consecutive frames to predict chance former_frame = obs_list_for_chance_encoder[step_i] latter_frame = obs_list_for_chance_encoder[step_i + 1] @@ -388,8 +394,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in chance_encoding, chance_one_hot = self._learn_model.chance_encode(concat_frame) if self._cfg.use_ture_chance_label_in_chance_encoder: true_chance_code = chance_batch[:, step_i] - chance_code = true_chance_code true_chance_one_hot = chance_one_hot_batch[:, step_i] + chance_code = true_chance_code else: chance_code = torch.argmax(chance_encoding, dim=1).long().unsqueeze(-1) @@ -441,11 +447,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # TODO(pu): if self._cfg.use_ture_chance_label_in_chance_encoder: afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, true_chance_one_hot.detach()) - # The encoder is not used i the mcts, so we don't need to calculate the commitment loss. + # The chance encoder is not used in the mcts, so we don't need to calculate the commitment loss. commitment_loss += torch.nn.MSELoss()(chance_encoding, true_chance_one_hot.float().detach()) else: afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_one_hot.detach()) - commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float().detach()) + # TODO(pu): whether to detach the chance_one_hot in the commitment loss. 
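# Illustrative sketch (not part of the patch above): the two commitment-loss variants weighed
# in the TODO. With `.detach()` the hard one-hot code is a fixed target and only the soft
# chance encoding is pulled towards it; without `.detach()`, and because `chance_one_hot`
# comes from the straight-through encoder, a second gradient path back into the encoder stays
# open. The tensors below are synthetic stand-ins, used purely to show the two forms.
import torch

mse = torch.nn.MSELoss()
chance_encoding = torch.randn(16, 32, requires_grad=True)
chance_one_hot = torch.zeros(16, 32).scatter_(1, torch.randint(0, 32, (16, 1)), 1.0)

loss_detached = mse(chance_encoding, chance_one_hot.float().detach())  # previous form
loss_attached = mse(chance_encoding, chance_one_hot.float())           # form kept by this change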
+ # commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float().detach()) + commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float()) afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_i]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) diff --git a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py index 7a47cb95f..6ea441f28 100644 --- a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py @@ -11,6 +11,16 @@ batch_size = 256 max_env_step = int(1e5) reanalyze_ratio = 0 + +# debug config +# collector_env_num = 1 +# n_episode = 1 +# evaluator_env_num = 1 +# num_simulations = 2 +# update_per_collect = 2 +# batch_size = 2 +# max_env_step = int(1e5) +# reanalyze_ratio = 0 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -31,7 +41,7 @@ observation_shape=4, action_space_size=2, chance_space_size=2, - model_type='conv', + model_type='mlp', lstm_hidden_size=128, latent_state_dim=128, self_supervised_learning_loss=True, # NOTE: default is False. @@ -77,16 +87,5 @@ create_config = cartpole_stochastic_muzero_create_config if __name__ == "__main__": - # Users can use different train entry by specifying the entry_type. - entry_type = "train_muzero" # options={"train_muzero", "train_muzero_with_gym_env"} - - if entry_type == "train_muzero": - from lzero.entry import train_muzero - elif entry_type == "train_muzero_with_gym_env": - """ - The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper. - Users can refer to lzero/envs/wrappers for more details. 
- """ - from lzero.entry import train_muzero_with_gym_env as train_muzero - + from lzero.entry import train_muzero train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) \ No newline at end of file From 1cbce65ce218af1e04c1ddbd530bab78a40dd301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 5 Sep 2023 18:14:29 +0800 Subject: [PATCH 22/28] polish(pu): polish stochastic muzero configs --- lzero/model/stochastic_muzero_model_mlp.py | 2 +- .../lunarlander_disc_stochastic_muzero_config.py | 2 +- .../config/cartpole_stochastic_muzero_config.py | 10 ---------- zoo/game_2048/config/muzero_2048_config.py | 10 ---------- zoo/game_2048/config/stochastic_muzero_2048_config.py | 10 ---------- 5 files changed, 2 insertions(+), 32 deletions(-) diff --git a/lzero/model/stochastic_muzero_model_mlp.py b/lzero/model/stochastic_muzero_model_mlp.py index 99e70f9b5..618428991 100644 --- a/lzero/model/stochastic_muzero_model_mlp.py +++ b/lzero/model/stochastic_muzero_model_mlp.py @@ -109,7 +109,7 @@ def __init__( self.representation_network = RepresentationNetworkMLP( observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type ) - + # TODO(pu): different input data type for chance_encoder self.chance_encoder = ChanceEncoder(observation_shape*2, chance_space_size) # input is two concatenated frames self.dynamics_network = DynamicsNetwork( action_encoding_dim=self.action_encoding_dim, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py index 444d5386b..21f40584f 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py @@ -31,7 +31,7 @@ observation_shape=8, action_space_size=4, chance_space_size=2, - model_type='conv', + model_type='mlp', lstm_hidden_size=256, latent_state_dim=256, self_supervised_learning_loss=True, # NOTE: default is False. diff --git a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py index 6ea441f28..4b21a830a 100644 --- a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py @@ -11,16 +11,6 @@ batch_size = 256 max_env_step = int(1e5) reanalyze_ratio = 0 - -# debug config -# collector_env_num = 1 -# n_episode = 1 -# evaluator_env_num = 1 -# num_simulations = 2 -# update_per_collect = 2 -# batch_size = 2 -# max_env_step = int(1e5) -# reanalyze_ratio = 0 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 64fbe475e..fa788fc3e 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -16,16 +16,6 @@ reanalyze_ratio = 0. num_of_possible_chance_tile = 2 chance_space_size = 16 * num_of_possible_chance_tile - -# debug config -# collector_env_num = 1 -# n_episode = 1 -# evaluator_env_num = 1 -# num_simulations = 5 -# update_per_collect = 3 -# batch_size = 5 -# max_env_step = int(1e6) -# reanalyze_ratio = 0. 
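# Illustrative sketch (not part of the patch above): the MLP chance encoder in this commit is
# built with input width `observation_shape * 2` because it consumes two consecutive flattened
# observations, o_t and o_{t+1}, concatenated along the feature dimension -- mirroring the
# frame concatenation used for the conv model. Shapes below are assumptions for illustration.
import torch

observation_shape, batch_size = 8, 4
former_frame = torch.randn(batch_size, observation_shape)  # o_t
latter_frame = torch.randn(batch_size, observation_shape)  # o_{t+1}
concat_frame = torch.cat((former_frame, latter_frame), dim=1)
assert concat_frame.shape == (batch_size, observation_shape * 2)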
# ============================================================== # end of the most frequently changed config specified by the user # ============================================================== diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index cdd4f37f8..6ec45e1c7 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -17,16 +17,6 @@ reanalyze_ratio = 0. num_of_possible_chance_tile = 2 chance_space_size = 16 * num_of_possible_chance_tile - -# debug config -# collector_env_num = 1 -# n_episode = 1 -# evaluator_env_num = 1 -# num_simulations = 5 -# update_per_collect = 3 -# batch_size = 5 -# max_env_step = int(1e6) -# reanalyze_ratio = 0. # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== From b6e9006b962cb876a03ee38976b112fc9f7617a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 5 Sep 2023 18:34:04 +0800 Subject: [PATCH 23/28] feature(pu): add analyze utlis for chance distribution --- lzero/policy/stochastic_muzero.py | 15 +- lzero/policy/utils.py | 171 ++++++++++++++++-- .../config/stochastic_muzero_2048_config.py | 1 + 3 files changed, 172 insertions(+), 15 deletions(-) diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 63c603269..db53ad8d0 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -16,6 +16,8 @@ from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ configure_optimizers +from lzero.policy.utils import calculate_topk_accuracy, plot_topk_accuracy, visualize_avg_softmax, \ + plot_argmax_distribution @POLICY_REGISTRY.register('stochastic_muzero') @@ -73,6 +75,8 @@ class StochasticMuZeroPolicy(Policy): battle_mode='play_with_bot_mode', # (bool) Whether to monitor extra statistics in tensorboard. monitor_extra_statistics=True, + # (bool) Whether to analyze the chance distribution. + analyze_chance_distribution=False, # (int) The transition number of one ``GameSegment``. game_segment_length=200, @@ -444,9 +448,18 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) - # TODO(pu): if self._cfg.use_ture_chance_label_in_chance_encoder: afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, true_chance_one_hot.detach()) + + if self._cfg.analyze_chance_distribution: + # visualize the avg softmax of afterstate_policy_logits + visualize_avg_softmax(afterstate_policy_logits) + # plot the argmax distribution of true_chance_one_hot + plot_argmax_distribution(true_chance_one_hot) + topK_values = range(1, self._cfg.model.chance_space_size+1) # top_K values from 1 to 32 + # calculate the topK accuracy of afterstate_policy_logits and plot the topK accuracy curve. + plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, topK_values) + # The chance encoder is not used in the mcts, so we don't need to calculate the commitment loss. 
commitment_loss += torch.nn.MSELoss()(chance_encoding, true_chance_one_hot.float().detach()) else: diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index cd220bd0e..61ceb455e 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -1,18 +1,158 @@ +import inspect import logging from typing import List, Tuple, Dict, Union -from easydict import EasyDict +import matplotlib.pyplot as plt import numpy as np -import torch -from scipy.stats import entropy -import math -import inspect - import torch import torch.nn as nn +from easydict import EasyDict +from scipy.stats import entropy from torch.nn import functional as F +def visualize_avg_softmax(logits): + """ + Overview: + Visualize the average softmax distribution across a minibatch. + Arguments: + logits (Tensor): The logits output from the model. + """ + # Apply softmax to logits to get the probabilities. + probabilities = F.softmax(logits, dim=1) + + # Compute the average probabilities across the minibatch. + avg_probabilities = torch.mean(probabilities, dim=0) + + # Convert to numpy for visualization. + avg_probabilities_np = avg_probabilities.detach().numpy() + + # Create a bar plot. + plt.figure(figsize=(10, 8)) + plt.bar(np.arange(len(avg_probabilities_np)), avg_probabilities_np) + + plt.xlabel('Classes') + plt.ylabel('Average Probability') + plt.title('Average Softmax Probabilities Across the Minibatch') + plt.show() + + +def calculate_topk_accuracy(logits, true_one_hot, top_k): + """ + Overview: + Calculate the top-k accuracy. + Arguments: + logits (Tensor): The logits output from the model. + true_one_hot (Tensor): The one-hot encoded true labels. + top_k (int): The number of top predictions to consider for a match. + Returns: + match_percentage (float): The percentage of matches in top-k predictions. + """ + # Apply softmax to logits to get the probabilities. + probabilities = F.softmax(logits, dim=1) + + # Use topk to find the indices of the highest k probabilities. + topk_indices = torch.topk(probabilities, top_k, dim=1)[1] + + # Get the true labels from the one-hot encoded tensor. + true_labels = torch.argmax(true_one_hot, dim=1).unsqueeze(1) + + # Compare the predicted top-k labels with the true labels. + matches = (topk_indices == true_labels).sum().item() + + # Calculate the percentage of matches. + match_percentage = matches / logits.size(0) * 100 + + return match_percentage + + +def plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, top_k_values): + """ + Overview: + Plot the top_K accuracy based on the given afterstate_policy_logits and true_chance_one_hot tensors. + Arguments: + afterstate_policy_logits (torch.Tensor): Tensor of shape (batch_size, num_classes) representing the logits. + true_chance_one_hot (torch.Tensor): Tensor of shape (batch_size, num_classes) representing the one-hot encoded true labels. + top_k_values (range or list): Range or list of top_K values to calculate the accuracy for. + Returns: + None (plots the graph) + """ + match_percentages = [] + for top_k in top_k_values: + match_percentage = calculate_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, top_k=top_k) + match_percentages.append(match_percentage) + + plt.plot(top_k_values, match_percentages) + plt.xlabel('top_K') + plt.ylabel('Match Percentage') + plt.title('Top_K Accuracy') + plt.show() + + +def compare_argmax(afterstate_policy_logits, chance_one_hot): + """ + Overview: + Compare the argmax of afterstate_policy_logits and chance_one_hot tensors. 
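# Illustrative sketch (not part of the patch above): a rough offline usage of the analysis
# helpers defined in this hunk, with random stand-ins for `afterstate_policy_logits` and the
# one-hot true chance codes. With 32 chance outcomes, random logits give roughly 3% top-1
# accuracy, rising towards 100% as top_k approaches the chance space size.
import torch
from lzero.policy.utils import calculate_topk_accuracy, plot_topk_accuracy

batch_size, chance_space_size = 1024, 32
afterstate_policy_logits = torch.randn(batch_size, chance_space_size)
true_chance_one_hot = torch.zeros(batch_size, chance_space_size).scatter_(
    1, torch.randint(0, chance_space_size, (batch_size, 1)), 1.0
)

top1 = calculate_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, top_k=1)
print(f"top-1 match percentage: {top1:.1f}%")
# plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, range(1, chance_space_size + 1))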
+ Arguments: + afterstate_policy_logits (torch.Tensor): Tensor of shape (batch_size, num_classes) representing the logits. + chance_one_hot (torch.Tensor): Tensor of shape (batch_size, num_classes) representing the one-hot encoded labels. + Returns: + None (plots the graph) + Example usage: + >>> afterstate_policy_logits = torch.randn(1024, 32) + >>> chance_one_hot = torch.randn(1024, 32) + >>> compare_argmax(afterstate_policy_logits, chance_one_hot) + """ + + # Calculate the argmax of afterstate_policy_logits and chance_one_hot tensors. + argmax_afterstate = torch.argmax(afterstate_policy_logits, dim=1) + argmax_chance = torch.argmax(chance_one_hot, dim=1) + + # Check if the argmax values are equal. + matches = (argmax_afterstate == argmax_chance) + + # Create a list of sample indices. + sample_indices = list(range(afterstate_policy_logits.size(0))) + + # Create a list to store the equality values (1 for equal, 0 for not equal). + equality_values = [int(match) for match in matches] + + # Plot the equality values. + plt.plot(sample_indices, equality_values) + plt.xlabel('Sample Index') + plt.ylabel('Equality') + plt.title('Comparison of argmax') + plt.show() + + +def plot_argmax_distribution(true_chance_one_hot): + """ + Overview: + Plot the distribution of possible values obtained from argmax(true_chance_one_hot). + Arguments: + true_chance_one_hot (torch.Tensor): Tensor of shape (batch_size, num_classes) representing the one-hot encoded true labels. + Returns: + None (plots the graph) + """ + + # Calculate the argmax of true_chance_one_hot tensor. + argmax_values = torch.argmax(true_chance_one_hot, dim=1) + + # Calculate the count of each unique argmax value. + unique_values, counts = torch.unique(argmax_values, return_counts=True) + + # Convert the tensor to a list for plotting. + unique_values = unique_values.tolist() + counts = counts.tolist() + + # Plot the distribution of argmax values. + plt.bar(unique_values, counts) + plt.xlabel('Argmax Values') + plt.ylabel('Count') + plt.title('Distribution of Argmax Values') + plt.show() + + class LayerNorm(nn.Module): """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ @@ -26,11 +166,11 @@ def forward(self, input): def configure_optimizers( - model: nn.Module, - weight_decay: float = 0, - learning_rate: float = 3e-3, - betas: tuple = (0.9, 0.999), - device_type: str = "cuda" + model: nn.Module, + weight_decay: float = 0, + learning_rate: float = 3e-3, + betas: tuple = (0.9, 0.999), + device_type: str = "cuda" ): """ Overview: @@ -90,7 +230,7 @@ def configure_optimizers( param_dict = {pn: p for pn, p in model.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) assert len( param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ % (str(param_dict.keys() - union_params),) @@ -302,7 +442,8 @@ def concat_output(output_lst: List, data_type: str = 'muzero') -> Tuple: ) -def to_torch_float_tensor(data_list: Union[np.ndarray, List[np.ndarray]], device: torch.device) -> Union[torch.Tensor, List[torch.Tensor]]: +def to_torch_float_tensor(data_list: Union[np.ndarray, List[np.ndarray]], device: torch.device) -> Union[ + torch.Tensor, List[torch.Tensor]]: """ Overview: convert the data or data list to torch float tensor @@ -322,7 +463,8 @@ def to_torch_float_tensor(data_list: Union[np.ndarray, List[np.ndarray]], device else: raise TypeError("The type of input must be np.ndarray or List[np.ndarray]") -def to_detach_cpu_numpy(data_list: Union[torch.Tensor, List[torch.Tensor]]) -> Union[np.ndarray,List[np.ndarray]]: + +def to_detach_cpu_numpy(data_list: Union[torch.Tensor, List[torch.Tensor]]) -> Union[np.ndarray, List[np.ndarray]]: """ Overview: convert the data or data list to detach cpu numpy. @@ -341,6 +483,7 @@ def to_detach_cpu_numpy(data_list: Union[torch.Tensor, List[torch.Tensor]]) -> U else: raise TypeError("The type of input must be torch.Tensor or List[torch.Tensor]") + def ez_network_output_unpack(network_output: Dict) -> Tuple: """ Overview: diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 6ec45e1c7..5854add7f 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -39,6 +39,7 @@ manager=dict(shared_memory=False, ), ), policy=dict( + model_path='/Users/puyuan/code/LightZero/zoo/game_2048/tb/game_2048_nct-2_stochastic_muzero_ns100_upc200_rr0.0_bs512_chance-True-32_sslw2_rbs1e6_seed0/ckpt/ckpt_best.pth.tar', model=dict( observation_shape=(16, 4, 4), action_space_size=action_space_size, From f4556ce048d3a9b3a07c1e7865a4aae2280c86e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 5 Sep 2023 21:41:16 +0800 Subject: [PATCH 24/28] polish(pu): delete model_path personal info --- zoo/game_2048/config/stochastic_muzero_2048_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 5854add7f..6ec45e1c7 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -39,7 +39,6 @@ manager=dict(shared_memory=False, ), ), policy=dict( - model_path='/Users/puyuan/code/LightZero/zoo/game_2048/tb/game_2048_nct-2_stochastic_muzero_ns100_upc200_rr0.0_bs512_chance-True-32_sslw2_rbs1e6_seed0/ckpt/ckpt_best.pth.tar', model=dict( observation_shape=(16, 4, 4), action_space_size=action_space_size, From 3b7bcb089768b1ee4e60fc0b759c4aaf1c9cc64f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 11 Sep 2023 00:15:00 +0800 Subject: [PATCH 25/28] polish(pu): add TestVisualizationFunctions, polish stochastic muzero model, rename xxx_eval_config.py to xxx_eval.py --- .../buffer/game_buffer_stochastic_muzero.py | 25 +- lzero/mcts/buffer/game_segment.py | 48 +- lzero/model/stochastic_muzero_model.py | 47 +- lzero/model/stochastic_muzero_model_mlp.py | 524 ++---------------- lzero/policy/efficientzero.py | 4 + lzero/policy/gumbel_muzero.py | 4 + lzero/policy/muzero.py | 22 +- lzero/policy/sampled_efficientzero.py | 5 +- lzero/policy/stochastic_muzero.py | 40 +- lzero/policy/tests/test_utlis.py | 131 ++++- .../{atari_eval_config.py 
=> atari_eval.py} | 0 ...val_config.py => gomoku_alphazero_eval.py} | 0 ...config.py => gomoku_gumbel_muzero_eval.py} | 0 ...o_eval_config.py => gomoku_muzero_eval.py} | 0 ..._config.py => tictactoe_alphazero_eval.py} | 0 ...val_config.py => tictactoe_muzero_eval.py} | 0 ...r_eval_config.py => bipedalwalker_eval.py} | 0 ...der_eval_config.py => lunarlander_eval.py} | 0 ...rtpole_eval_config.py => cartpole_eval.py} | 0 ...ndulum_eval_config.py => pendulum_eval.py} | 0 zoo/game_2048/config/muzero_2048_config.py | 2 +- .../config/stochastic_muzero_2048_config.py | 2 +- zoo/game_2048/entry/2048_bot_eval.py | 57 ++ ...nfig.py => stochastic_muzero_2048_eval.py} | 8 +- .../expectimax_search_based_bot.py} | 64 +-- zoo/game_2048/envs/game_2048_env.py | 88 ++- 26 files changed, 448 insertions(+), 623 deletions(-) rename zoo/atari/entry/{atari_eval_config.py => atari_eval.py} (100%) rename zoo/board_games/gomoku/entry/{gomoku_alphazero_eval_config.py => gomoku_alphazero_eval.py} (100%) rename zoo/board_games/gomoku/entry/{gomoku_gumbel_muzero_eval_config.py => gomoku_gumbel_muzero_eval.py} (100%) rename zoo/board_games/gomoku/entry/{gomoku_muzero_eval_config.py => gomoku_muzero_eval.py} (100%) rename zoo/board_games/tictactoe/entry/{tictactoe_alphazero_eval_config.py => tictactoe_alphazero_eval.py} (100%) rename zoo/board_games/tictactoe/entry/{tictactoe_muzero_eval_config.py => tictactoe_muzero_eval.py} (100%) rename zoo/box2d/bipedalwalker/entry/{bipedalwalker_eval_config.py => bipedalwalker_eval.py} (100%) rename zoo/box2d/lunarlander/entry/{lunarlander_eval_config.py => lunarlander_eval.py} (100%) rename zoo/classic_control/cartpole/entry/{cartpole_eval_config.py => cartpole_eval.py} (100%) rename zoo/classic_control/pendulum/entry/{pendulum_eval_config.py => pendulum_eval.py} (100%) create mode 100644 zoo/game_2048/entry/2048_bot_eval.py rename zoo/game_2048/entry/{stochastic_muzero_2048_eval_config.py => stochastic_muzero_2048_eval.py} (86%) rename zoo/game_2048/{entry/rule_based_2048_config.py => envs/expectimax_search_based_bot.py} (72%) diff --git a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py index c068f265e..65552b9af 100644 --- a/lzero/mcts/buffer/game_buffer_stochastic_muzero.py +++ b/lzero/mcts/buffer/game_buffer_stochastic_muzero.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple +from typing import Any, Tuple, List import numpy as np from ding.utils import BUFFER_REGISTRY @@ -146,3 +146,26 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: context = reward_value_context, policy_re_context, policy_non_re_context, current_batch return context + + def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None: + """ + Overview: + Update the priority of training data. + Arguments: + - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. + - batch_priorities (:obj:`batch_priorities`): priorities to update to. 
+ NOTE: + train_data = [current_batch, target_batch] + if self._cfg.use_ture_chance_label_in_chance_encoder: + obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch + else: + obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch + + """ + indices = train_data[0][3] + metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + if metas['make_time'][i] > self.clear_time: + idx, prio = indices[i], metas['batch_priorities'][i] + self.game_pos_priorities[idx] = prio diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index 9b0b64859..6bff11359 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -41,11 +41,17 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea """ self.action_space = action_space self.game_segment_length = game_segment_length - self.config = config - + self.num_unroll_steps = config.num_unroll_steps + self.td_steps = config.td_steps self.frame_stack_num = config.model.frame_stack_num self.discount_factor = config.discount_factor self.action_space_size = config.model.action_space_size + self.gray_scale = config.gray_scale + self.transform2string = config.transform2string + self.sampled_algo = config.sampled_algo + self.gumbel_algo = config.gumbel_algo + self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder + if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: # for vector obs input, e.g. classical control and box2d environments self.zero_obs_shape = config.model.observation_shape @@ -71,9 +77,9 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.improved_policy_probs = [] - if self.config.sampled_algo: + if self.sampled_algo: self.root_sampled_actions = [] - if self.config.use_ture_chance_label_in_chance_encoder: + if self.use_ture_chance_label_in_chance_encoder: self.chance_segment = [] @@ -92,8 +98,8 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool if pad_len > 0: pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)]) stacked_obs = np.concatenate((stacked_obs, pad_frames)) - if self.config.transform2string: - stacked_obs = [jpeg_data_decompressor(obs, self.config.gray_scale) for obs in stacked_obs] + if self.transform2string: + stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] return stacked_obs def zero_obs(self) -> List: @@ -119,8 +125,8 @@ def get_obs(self) -> List: ) timestep = timestep_reward stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num] - if self.config.transform2string: - stacked_obs = [jpeg_data_decompressor(obs, self.config.gray_scale) for obs in stacked_obs] + if self.transform2string: + stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs] return stacked_obs def append( @@ -142,7 +148,7 @@ def append( self.action_mask_segment.append(action_mask) self.to_play_segment.append(to_play) - if self.config.use_ture_chance_label_in_chance_encoder: + if self.use_ture_chance_label_in_chance_encoder: self.chance_segment.append(chance) def pad_over( @@ -162,15 +168,15 @@ def pad_over( - next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment - next_segment_improved_policy (:obj:`list`): 
root children select policy of MCTS from the next game_segment (Only used in Gumbel MuZero) """ - assert len(next_segment_observations) <= self.config.num_unroll_steps - assert len(next_segment_child_visits) <= self.config.num_unroll_steps - assert len(next_segment_root_values) <= self.config.num_unroll_steps + self.config.td_steps - assert len(next_segment_rewards) <= self.config.num_unroll_steps + self.config.td_steps - 1 + assert len(next_segment_observations) <= self.num_unroll_steps + assert len(next_segment_child_visits) <= self.num_unroll_steps + assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps + assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1 # ============================================================== # The core difference between GumbelMuZero and MuZero # ============================================================== - if self.config.gumbel_algo: - assert len(next_segment_improved_policy) <= self.config.num_unroll_steps + self.config.td_steps + if self.gumbel_algo: + assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory for observation in next_segment_observations: @@ -185,10 +191,10 @@ def pad_over( for child_visits in next_segment_child_visits: self.child_visit_segment.append(child_visits) - if self.config.gumbel_algo: + if self.gumbel_algo: for improved_policy in next_segment_improved_policy: self.improved_policy_probs.append(improved_policy) - if self.config.use_ture_chance_label_in_chance_encoder: + if self.use_ture_chance_label_in_chance_encoder: for chances in next_chances: self.chance_segment.append(chances) @@ -210,10 +216,10 @@ def store_search_stats( if idx is None: self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts]) self.root_value_segment.append(root_value) - if self.config.sampled_algo: + if self.sampled_algo: self.root_sampled_actions.append(root_sampled_actions) # store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ)) - if self.config.gumbel_algo: + if self.gumbel_algo: self.improved_policy_probs.append(improved_policy) else: self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts] @@ -268,7 +274,7 @@ def game_segment_to_array(self) -> None: self.action_mask_segment = np.array(self.action_mask_segment) self.to_play_segment = np.array(self.to_play_segment) - if self.config.use_ture_chance_label_in_chance_encoder: + if self.use_ture_chance_label_in_chance_encoder: self.chance_segment = np.array(self.chance_segment) def reset(self, init_observations: np.ndarray) -> None: @@ -288,7 +294,7 @@ def reset(self, init_observations: np.ndarray) -> None: self.action_mask_segment = [] self.to_play_segment = [] - if self.config.use_ture_chance_label_in_chance_encoder: + if self.use_ture_chance_label_in_chance_encoder: self.chance_segment = [] assert len(init_observations) == self.frame_stack_num diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 87700ea17..b6e6b7765 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -44,13 +44,15 @@ def __init__( ): """ Overview: - The definition of the neural network model used in Stochastic MuZero. - Stochastic MuZero model which consists of a representation network, a dynamics network and a prediction network.
- The networks are build on convolution residual blocks and fully connected layers. + The definition of the neural network model used in Stochastic MuZero, + which is proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1. + Stochastic MuZero model consists of a representation network, a dynamics network and a prediction network. + The networks are built on convolution residual blocks and fully connected layers. Arguments: - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[12, 96, 96] for Atari. - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - chance_space_size: (:obj:`int`): Chance space size, the action space for decision node, usually an integer number for discrete action space. + - chance_space_size: (:obj:`int`): Chance space size, the action space for decision node, usually an integer + number for discrete action space. - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - num_channels (:obj:`int`): The channels of hidden states. - reward_head_channels (:obj:`int`): The channels of reward head. @@ -801,13 +803,33 @@ def forward(self, x): return x +class ChanceEncoderBackboneMLP(nn.Module): + def __init__(self, input_dimensions, chance_encoding_dim=4): + super(ChanceEncoderBackboneMLP, self).__init__() + self.fc1 = nn.Linear(input_dimensions, 128) + self.fc2 = nn.Linear(128, 64) + self.fc3 = nn.Linear(64, chance_encoding_dim) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + x = self.fc3(x) + return x + + class ChanceEncoder(nn.Module): - def __init__(self, observation_space_dimensions, action_dimension): + def __init__(self, input_dimensions, action_dimension, encoder_backbone_type='conv'): super().__init__() # Specify the action space for the model self.action_space = action_dimension - # Define the encoder, which transforms observations into a latent space - self.chance_encoder = ChanceEncoderBackbone(observation_space_dimensions, action_dimension) + if encoder_backbone_type == 'conv': + # Define the encoder, which transforms observations into a latent space + self.chance_encoder = ChanceEncoderBackbone(input_dimensions, action_dimension) + elif encoder_backbone_type == 'mlp': + self.chance_encoder = ChanceEncoderBackboneMLP(input_dimensions, action_dimension) + else: + raise ValueError('Encoder backbone type not supported') + # Using the Straight Through Estimator method for backpropagation self.onehot_argmax = StraightThroughEstimator() @@ -821,11 +843,11 @@ def forward(self, observations): Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5, Chance Outcomes section. - Args: + Arguments: observations (Tensor): Observation tensor. Returns: - chance_t (Tensor): Transformed tensor after applying one-hot argmax. + chance (Tensor): Transformed tensor after applying one-hot argmax. chance_encoding (Tensor): Encoding of the input observation tensor. """ # Apply the encoder to the observation @@ -844,7 +866,7 @@ def forward(self, x): Forward method for the StraightThroughEstimator. This applies the one-hot argmax function to the input tensor. - Args: + Arguments: x (Tensor): Input tensor. Returns: @@ -872,7 +894,7 @@ def forward(ctx, input): Forward method for the one-hot argmax function. This method transforms the input tensor into a one-hot tensor. - Args: + Arguments: ctx (context): A context object that can be used to stash information for backward computation. input (Tensor): Input tensor. 
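For reference, the straight-through one-hot argmax used by ChanceEncoder above can also be written without a custom autograd.Function. The following is a minimal, self-contained sketch of that equivalent formulation (the helper name is illustrative and not part of the patch): the forward value is the hard one-hot of the argmax, while the gradient with respect to the logits is the identity, which is exactly what OnehotArgmax.backward achieves by returning grad_output unchanged.

import torch

def straight_through_onehot_argmax(logits: torch.Tensor) -> torch.Tensor:
    # Forward pass: hard one-hot of the per-row argmax.
    one_hot = torch.zeros_like(logits).scatter_(1, logits.argmax(dim=1, keepdim=True), 1.0)
    # Straight-through trick: numerically equal to one_hot, but gradients flow
    # back to `logits` as if the whole op were the identity.
    return one_hot + logits - logits.detach()

# Quick check: output rows are one-hot and gradients reach the logits.
logits = torch.randn(4, 32, requires_grad=True)
chance_onehot = straight_through_onehot_argmax(logits)
chance_onehot.sum().backward()
assert chance_onehot.sum(dim=1).eq(1).all() and logits.grad is not None

This is what allows gradients to reach the chance encoder during backpropagation even though the code it emits in the forward pass is a discrete one-hot vector.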
@@ -889,7 +911,7 @@ def backward(ctx, grad_output): Backward method for the one-hot argmax function. This method allows gradients to flow to the encoder during backpropagation. - Args: + Arguments: ctx (context): A context object that was stashed in the forward pass. grad_output (Tensor): The gradient of the output tensor. @@ -897,4 +919,3 @@ def backward(ctx, grad_output): Tensor: The gradient of the input tensor. """ return grad_output - diff --git a/lzero/model/stochastic_muzero_model_mlp.py b/lzero/model/stochastic_muzero_model_mlp.py index 618428991..f9575820e 100644 --- a/lzero/model/stochastic_muzero_model_mlp.py +++ b/lzero/model/stochastic_muzero_model_mlp.py @@ -2,42 +2,42 @@ import torch import torch.nn as nn -from ding.torch_utils import MLP from ding.utils import MODEL_REGISTRY, SequenceType -from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP -from .stochastic_muzero_model import StraightThroughEstimator -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .common import RepresentationNetworkMLP, PredictionNetworkMLP +from .muzero_model_mlp import DynamicsNetwork +from .stochastic_muzero_model import StochasticMuZeroModel, ChanceEncoder +from .utils import renormalize @MODEL_REGISTRY.register('StochasticMuZeroModelMLP') -class StochasticMuZeroModelMLP(nn.Module): +class StochasticMuZeroModelMLP(StochasticMuZeroModel): def __init__( - self, - observation_shape: int = 2, - action_space_size: int = 6, - chance_space_size: int = 2, - latent_state_dim: int = 256, - fc_reward_layers: SequenceType = [32], - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - reward_support_size: int = 601, - value_support_size: int = 601, - proj_hid: int = 1024, - proj_out: int = 1024, - pred_hid: int = 512, - pred_out: int = 1024, - self_supervised_learning_loss: bool = False, - categorical_distribution: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - last_linear_layer_init_zero: bool = True, - state_norm: bool = False, - discrete_action_encoding_type: str = 'one_hot', - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, - *args, - **kwargs + self, + observation_shape: int = 2, + action_space_size: int = 6, + chance_space_size: int = 2, + latent_state_dim: int = 256, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + discrete_action_encoding_type: str = 'one_hot', + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + *args, + **kwargs ): """ Overview: @@ -110,7 +110,8 @@ def __init__( observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type ) # TODO(pu): different input data type for chance_encoder - self.chance_encoder = ChanceEncoder(observation_shape*2, chance_space_size) # input is two concatenated frames + self.chance_encoder = ChanceEncoder(observation_shape * 2, chance_space_size, + encoder_backbone_type='mlp') # input is two concatenated frames self.dynamics_network = DynamicsNetwork( action_encoding_dim=self.action_encoding_dim, 
num_channels=self.latent_state_dim + self.action_encoding_dim, @@ -168,136 +169,6 @@ def __init__( nn.Linear(self.pred_hid, self.pred_out), ) - def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: - """ - Overview: - Initial inference of Stochastic model, which is the first step of the Stochastic model. - To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. - Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and - also prepare the zeros-like ``reward`` for the next step of the Stochastic model. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ - In initial inference, we set it to zero vector. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ - we set it to the zeros-like hidden state (H and C). - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - batch_size = obs.size(0) - latent_state = self._representation(obs) - policy_logits, value = self._prediction(latent_state) - return MZNetworkOutput( - value, - [0. for _ in range(batch_size)], - policy_logits, - latent_state, - ) - - def recurrent_inference(self, state: torch.Tensor, option: torch.Tensor, - afterstate: bool = False) -> MZNetworkOutput: - """ - Overview: - Recurrent inference of Stochastic MuZero model, which is the rollout step of the Stochastic MuZero model. - To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, - ``reward``, by the given current ``latent_state`` and ``action``. - We then use the prediction network to predict the ``value`` and ``policy_logits`` of the current - ``latent_state``. - Arguments: - - state (:obj:`torch.Tensor`): The encoding latent state of input state or the afterstate. - - option (:obj:`torch.Tensor`): The action to rollout or the chance to predict next latent state. - - afterstate (:obj:`bool`): Whether to use afterstate prediction network to predict next latent state. - Returns (MZNetworkOutput): - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. 
- - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. - - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. - """ - - if afterstate: - # state is afterstate, option is chance - next_latent_state, reward = self._dynamics(state, option) - policy_logits, value = self._prediction(next_latent_state) - return MZNetworkOutput(value, reward, policy_logits, next_latent_state) - else: - # state is latent_state, option is action - next_afterstate, reward = self._afterstate_dynamics(state, option) - policy_logits, value = self._afterstate_prediction(next_afterstate) - return MZNetworkOutput(value, reward, policy_logits, next_afterstate) - - def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - Shapes: - - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - """ - latent_state = self.representation_network(observation) - if self.state_norm: - latent_state = renormalize(latent_state) - return latent_state - - def chance_encode(self, observation: torch.Tensor): - output = self.chance_encoder(observation) - return output - - def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Use the representation network to encode the observations into latent state. - Arguments: - - obs (:obj:`torch.Tensor`): The 1D vector observation data. - Returns: - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. - - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - """ - policy_logits, value = self.prediction_network(latent_state) - return policy_logits, value - - def _afterstate_prediction(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Use the prediction network to predict ``policy_logits`` and ``value``. - Arguments: - - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - Returns: - - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. - - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. - Shapes: - - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ - latent state, W_ is the width of latent state. 
- - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. - - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. - """ - return self.afterstate_prediction_network(afterstate) def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: @@ -354,7 +225,8 @@ def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[t next_latent_state_normalized = renormalize(next_latent_state) return next_latent_state_normalized, reward - def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor]: """ Overview: Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` @@ -437,145 +309,20 @@ def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: else: return proj.detach() - def get_params_mean(self) -> float: - return get_params_mean(self) - - -class DynamicsNetwork(nn.Module): - - def __init__( - self, - action_encoding_dim: int = 2, - num_channels: int = 64, - common_layer_num: int = 2, - fc_reward_layers: SequenceType = [32], - output_support_size: int = 601, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, - ): - """ - Overview: - The definition of dynamics network in Stochastic algorithm, which is used to predict next latent state - reward by the given current latent state and action. - The networks are mainly built on fully connected layers. - Arguments: - - action_encoding_dim (:obj:`int`): The dimension of action encoding. - - num_channels (:obj:`int`): The num of channels in latent states. - - common_layer_num (:obj:`int`): The number of common layers in dynamics network. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - output_support_size (:obj:`int`): The size of categorical reward output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
- """ - super().__init__() - self.num_channels = num_channels - self.action_encoding_dim = action_encoding_dim - self.latent_state_dim = self.num_channels - self.action_encoding_dim - - self.res_connection_in_dynamics = res_connection_in_dynamics - if self.res_connection_in_dynamics: - self.fc_dynamics_1 = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - self.fc_dynamics_2 = MLP( - in_channels=self.latent_state_dim, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - else: - self.fc_dynamics = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - self.fc_reward_head = MLP( - in_channels=self.latent_state_dim, - hidden_channels=fc_reward_layers[0], - layer_num=2, - out_channels=output_support_size, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Forward computation of the dynamics network. Predict the next latent state given current latent state and action. - Arguments: - - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ - latent state and action encoding, with shape (batch_size, num_channels, height, width). - Returns: - - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). - - reward (:obj:`torch.Tensor`): The predicted reward for input state. - """ - if self.res_connection_in_dynamics: - # take the state encoding (e.g. 
latent_state), - # state_action_encoding[:, -self.action_encoding_dim:] is action encoding - latent_state = state_action_encoding[:, :-self.action_encoding_dim] - x = self.fc_dynamics_1(state_action_encoding) - # the residual link: add the latent_state to the state_action encoding - next_latent_state = x + latent_state - next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) - else: - next_latent_state = self.fc_dynamics(state_action_encoding) - next_latent_state_encoding = next_latent_state - - reward = self.fc_reward_head(next_latent_state_encoding) - - return next_latent_state, reward - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> float: - return get_reward_mean(self) - - -class AfterstateDynamicsNetwork(nn.Module): +class AfterstateDynamicsNetwork(DynamicsNetwork): def __init__( - self, - action_encoding_dim: int = 2, - num_channels: int = 64, - common_layer_num: int = 2, - fc_reward_layers: SequenceType = [32], - output_support_size: int = 601, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, ): """ Overview: @@ -594,62 +341,10 @@ def __init__( - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
""" - super().__init__() - self.num_channels = num_channels - self.action_encoding_dim = action_encoding_dim - self.latent_state_dim = self.num_channels - self.action_encoding_dim - - self.res_connection_in_dynamics = res_connection_in_dynamics - if self.res_connection_in_dynamics: - self.fc_dynamics_1 = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - self.fc_dynamics_2 = MLP( - in_channels=self.latent_state_dim, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - else: - self.fc_dynamics = MLP( - in_channels=self.num_channels, - hidden_channels=self.latent_state_dim, - layer_num=common_layer_num, - out_channels=self.latent_state_dim, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - self.fc_reward_head = MLP( - in_channels=self.latent_state_dim, - hidden_channels=fc_reward_layers[0], - layer_num=2, - out_channels=output_support_size, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - last_linear_layer_init_zero=last_linear_layer_init_zero - ) + super(AfterstateDynamicsNetwork, self).__init__(action_encoding_dim, num_channels, common_layer_num, + fc_reward_layers, output_support_size, + last_linear_layer_init_zero + , activation, norm_type, res_connection_in_dynamics) def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -662,30 +357,10 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). - reward (:obj:`torch.Tensor`): The predicted reward for input state. """ - if self.res_connection_in_dynamics: - # take the state encoding (e.g. latent_state), - # state_action_encoding[:, -self.action_encoding_dim:] is action encoding - latent_state = state_action_encoding[:, :-self.action_encoding_dim] - x = self.fc_dynamics_1(state_action_encoding) - # the residual link: add the latent_state to the state_action encoding - next_latent_state = x + latent_state - next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) - else: - next_latent_state = self.fc_dynamics(state_action_encoding) - next_latent_state_encoding = next_latent_state - - reward = self.fc_reward_head(next_latent_state_encoding) + return super(AfterstateDynamicsNetwork, self).forward(state_action_encoding) - return next_latent_state, reward - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> float: - return get_reward_mean(self) - - -class AfterstatePredictionNetworkMLP(nn.Module): +class AfterstatePredictionNetworkMLP(PredictionNetworkMLP): def __init__( self, @@ -716,48 +391,10 @@ def __init__( operation to speedup, e.g. ReLU(inplace=True). - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
""" - super().__init__() - self.num_channels = num_channels - - # ******* common backbone ****** - self.fc_prediction_common = MLP( - in_channels=self.num_channels, - hidden_channels=self.num_channels, - out_channels=self.num_channels, - layer_num=common_layer_num, - activation=activation, - norm_type=norm_type, - output_activation=True, - output_norm=True, - # last_linear_layer_init_zero=False is important for convergence - last_linear_layer_init_zero=False, - ) - - # ******* value and policy head ****** - self.fc_value_head = MLP( - in_channels=self.num_channels, - hidden_channels=fc_value_layers[0], - out_channels=output_support_size, - layer_num=len(fc_value_layers) + 1, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.fc_policy_head = MLP( - in_channels=self.num_channels, - hidden_channels=fc_policy_layers[0], - out_channels=chance_space_size, - layer_num=len(fc_policy_layers) + 1, - activation=activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) + super(AfterstatePredictionNetworkMLP, self).__init__(chance_space_size, num_channels, common_layer_num, + fc_value_layers, fc_policy_layers, output_support_size, + last_linear_layer_init_zero + , activation, norm_type) def forward(self, latent_state: torch.Tensor): """ @@ -769,55 +406,4 @@ def forward(self, latent_state: torch.Tensor): - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). """ - x_prediction_common = self.fc_prediction_common(latent_state) - - value = self.fc_value_head(x_prediction_common) - policy = self.fc_policy_head(x_prediction_common) - return policy, value - -class ChanceEncoderBackbone(nn.Module): - def __init__(self, input_dimensions, chance_encoding_dim=4): - super(ChanceEncoderBackbone, self).__init__() - self.fc1 = nn.Linear(input_dimensions, 128) - self.fc2 = nn.Linear(128, 64) - self.fc3 = nn.Linear(64, chance_encoding_dim) - - def forward(self, x): - x = torch.relu(self.fc1(x)) - x = torch.relu(self.fc2(x)) - x = self.fc3(x) - return x - - -class ChanceEncoder(nn.Module): - def __init__(self, input_dimensions, action_dimension): - super().__init__() - # Specify the action space for the model - self.action_space = action_dimension - # Define the encoder, which transforms observations into a latent space - self.encoder = ChanceEncoderBackbone(input_dimensions, action_dimension) - # Using the Straight Through Estimator method for backpropagation - self.onehot_argmax = StraightThroughEstimator() - - def forward(self, observations): - """ - Forward method for the ChanceEncoder. This method takes an observation - and applies the encoder to transform it to a latent space. Then applies the - StraightThroughEstimator to this encoding. - - References: - Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5, - Chance Outcomes section. - - Args: - observations (Tensor): Observation tensor. - - Returns: - chance_t (Tensor): Transformed tensor after applying one-hot argmax. - chance_encoding (Tensor): Encoding of the input observation tensor. 
- """ - # Apply the encoder to the observation - chance_encoding = self.encoder(observations) - # Apply one-hot argmax to the encoding - chance_onehot = self.onehot_argmax(chance_encoding) - return chance_encoding, chance_onehot \ No newline at end of file + return super(AfterstatePredictionNetworkMLP, self).forward(latent_state) diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index ca2fcedc2..6dcc97636 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -86,6 +86,8 @@ class EfficientZeroPolicy(Policy): # ****** observation ****** # (bool) Whether to transform image to string to save memory. transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, # (bool) Whether to use data augmentation. use_augmentation=False, # (list) The style of augmentation. @@ -204,6 +206,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'EfficientZeroModel', ['lzero.model.efficientzero_model'] elif self._cfg.model.model_type == "mlp": return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) def _init_learn(self) -> None: """ diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index 43ff7ef5c..c0520efe4 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -87,6 +87,8 @@ class GumeblMuZeroPolicy(Policy): # ****** observation ****** # (bool) Whether to transform image to string to save memory. transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, # (bool) Whether to use data augmentation. use_augmentation=False, # (list) The style of augmentation. @@ -185,6 +187,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModel', ['lzero.model.muzero_model'] elif self._cfg.model.model_type == "mlp": return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) def _init_learn(self) -> None: """ diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 0717a84af..5aee5b82b 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -14,8 +14,7 @@ from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ - configure_optimizers + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs @POLICY_REGISTRY.register('muzero') @@ -87,6 +86,8 @@ class MuZeroPolicy(Policy): # ****** observation ****** # (bool) Whether to transform image to string to save memory. transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, # (bool) Whether to use data augmentation. use_augmentation=False, # (list) The style of augmentation. @@ -108,7 +109,7 @@ class MuZeroPolicy(Policy): model_update_ratio=0.1, # (int) Minibatch size for one gradient descent. batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + # (str) Optimizer for training policy network. ['SGD', 'Adam'] optim_type='SGD', # (float) Learning rate for training policy network. Initial lr for manually decay schedule. 
learning_rate=0.2, @@ -204,13 +205,15 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModel', ['lzero.model.muzero_model'] elif self._cfg.model.model_type == "mlp": return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) def _init_learn(self) -> None: """ Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW_official', 'AdamW_nanoGPT'], self._cfg.optim_type + assert self._cfg.optim_type in ['SGD', 'Adam'], self._cfg.optim_type # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( @@ -223,17 +226,6 @@ def _init_learn(self) -> None: self._optimizer = optim.Adam( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) - elif self._cfg.optim_type == 'AdamW': - self._optimizer = optim.AdamW( - self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay - ) - elif self._cfg.optim_type == 'AdamW_nanoGPT': - self._optimizer = configure_optimizers( - model=self._model, - weight_decay=self._cfg.weight_decay, - learning_rate=self._cfg.learning_rate, - device_type=self._cfg.device - ) if self._cfg.lr_piecewise_constant_decay: from torch.optim.lr_scheduler import LambdaLR diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index abe0e64c7..9691983cb 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -93,6 +93,8 @@ class SampledEfficientZeroPolicy(Policy): # ****** observation ****** # (bool) Whether to transform image to string to save memory. transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, # (bool) Whether to use data augmentation. use_augmentation=False, # (list) The style of augmentation. @@ -220,7 +222,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] elif self._cfg.model.model_type == "mlp": return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] - + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) def _init_learn(self) -> None: """ Overview: diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index db53ad8d0..efaecb091 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -16,18 +16,17 @@ from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ configure_optimizers -from lzero.policy.utils import calculate_topk_accuracy, plot_topk_accuracy, visualize_avg_softmax, \ - plot_argmax_distribution +from lzero.policy.utils import plot_topk_accuracy, visualize_avg_softmax, plot_argmax_distribution @POLICY_REGISTRY.register('stochastic_muzero') class StochasticMuZeroPolicy(Policy): """ Overview: - The policy class for Stochastic MuZero. + The policy class for Stochastic MuZero proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1. """ - # The default_config for MuZero policy. + # The default_config for Stochastic MuZero policy. config = dict( model=dict( # (str) The model type. 
For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. @@ -83,6 +82,8 @@ class StochasticMuZeroPolicy(Policy): # ****** observation ****** # (bool) Whether to transform image to string to save memory. transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, # (bool) Whether to use data augmentation. use_augmentation=False, # (list) The style of augmentation. @@ -103,7 +104,7 @@ class StochasticMuZeroPolicy(Policy): model_update_ratio=0.1, # (int) Minibatch size for one gradient descent. batch_size=256, - # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + # (str) Optimizer for training policy network. ['SGD', 'Adam'] optim_type='Adam', # (float) Learning rate for training policy network. Ininitial lr for manually decay schedule. learning_rate=int(3e-3), @@ -207,13 +208,15 @@ def default_model(self) -> Tuple[str, List[str]]: return 'StochasticMuZeroModel', ['lzero.model.stochastic_muzero_model'] elif self._cfg.model.model_type == "mlp": return 'StochasticMuZeroModelMLP', ['lzero.model.stochastic_muzero_model_mlp'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) def _init_learn(self) -> None: """ Overview: Learn mode init method. Called by ``self.__init__``. Ininitialize the learn model, optimizer and MCTS utils. """ - assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW', 'AdamW_nanoGPT'], self._cfg.optim_type + assert self._cfg.optim_type in ['SGD', 'Adam'], self._cfg.optim_type # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( @@ -226,17 +229,6 @@ def _init_learn(self) -> None: self._optimizer = optim.Adam( self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) - elif self._cfg.optim_type == 'AdamW': - self._optimizer = optim.AdamW( - self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay - ) - elif self._cfg.optim_type == 'AdamW_nanoGPT': - self._optimizer = configure_optimizers( - model=self._model, - weight_decay=self._cfg.weight_decay, - learning_rate=self._cfg.learning_rate, - device_type=self._cfg.device - ) if self._cfg.lr_piecewise_constant_decay: from torch.optim.lr_scheduler import LambdaLR @@ -315,8 +307,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() data_list = [ mask_batch, - target_reward.astype('float64'), - target_value.astype('float64'), target_policy, weights + target_reward.astype('float32'), + target_value.astype('float32'), target_policy, weights ] [mask_batch, target_reward, target_value, target_policy, weights] = to_torch_float_tensor(data_list, self._cfg.device) @@ -464,6 +456,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in commitment_loss += torch.nn.MSELoss()(chance_encoding, true_chance_one_hot.float().detach()) else: afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, chance_one_hot.detach()) + + if self._cfg.analyze_chance_distribution: + # visualize the avg softmax of afterstate_policy_logits + visualize_avg_softmax(afterstate_policy_logits) + # plot the argmax distribution of true_chance_one_hot + plot_argmax_distribution(true_chance_one_hot) + topK_values = range(1, self._cfg.model.chance_space_size+1) # top_K values from 1 to 32 + # calculate the topK accuracy of afterstate_policy_logits 
and plot the topK accuracy curve. + plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, topK_values) + # TODO(pu): whether to detach the chance_one_hot in the commitment loss. # commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float().detach()) commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float()) diff --git a/lzero/policy/tests/test_utlis.py b/lzero/policy/tests/test_utlis.py index 12f18540c..2243d9284 100644 --- a/lzero/policy/tests/test_utlis.py +++ b/lzero/policy/tests/test_utlis.py @@ -1,34 +1,147 @@ +import numpy as np import pytest import torch -import numpy as np +import torch.nn.functional as F -from lzero.policy.utils import negative_cosine_similarity, to_torch_float_tensor +from lzero.policy.utils import negative_cosine_similarity, to_torch_float_tensor, visualize_avg_softmax, \ + calculate_topk_accuracy, plot_topk_accuracy, compare_argmax, plot_argmax_distribution +# We use the pytest.mark.unittest decorator to mark this class for unit testing. +@pytest.mark.unittest +class TestVisualizationFunctions: + + def test_visualize_avg_softmax(self): + """ + This test checks whether the visualize_avg_softmax function correctly + computes the average softmax probabilities and visualizes them. + """ + + # We initialize the input parameters. + batch_size = 256 + num_classes = 10 + logits = torch.randn(batch_size, num_classes) + + # We call the visualize_avg_softmax function. + visualize_avg_softmax(logits) + + # This function does not return anything, it only creates a plot. + # Therefore, we can only visually inspect the plot to check if it is correct. + + def test_calculate_topk_accuracy(self): + """ + This test checks whether the calculate_topk_accuracy function correctly + computes the top-k accuracy. + """ + + # We initialize the input parameters. + batch_size = 256 + num_classes = 10 + logits = torch.randn(batch_size, num_classes) + true_labels = torch.randint(0, num_classes, [batch_size]) + true_one_hot = F.one_hot(true_labels, num_classes) + top_k = 5 + + # We call the calculate_topk_accuracy function. + match_percentage = calculate_topk_accuracy(logits, true_one_hot, top_k) + + # We check if the match percentage is a float and within the range [0, 100]. + assert isinstance(match_percentage, float) + assert 0 <= match_percentage <= 100 + + def test_plot_topk_accuracy(self): + """ + This test checks whether the plot_topk_accuracy function correctly + plots the top-k accuracy for different values of k. + """ + + # We initialize the input parameters. + batch_size = 256 + num_classes = 10 + logits = torch.randn(batch_size, num_classes) + true_labels = torch.randint(0, num_classes, [batch_size]) + true_one_hot = F.one_hot(true_labels, num_classes) + top_k_values = range(1, 6) + + # We call the plot_topk_accuracy function. + plot_topk_accuracy(logits, true_one_hot, top_k_values) + + # This function does not return anything, it only creates a plot. + # Therefore, we can only visually inspect the plot to check if it is correct. + + def test_compare_argmax(self): + """ + This test checks whether the compare_argmax function correctly + plots the comparison of argmax values. + """ + + # We initialize the input parameters. + batch_size = 256 + num_classes = 10 + logits = torch.randn(batch_size, num_classes) + true_labels = torch.randint(0, num_classes, [batch_size]) + chance_one_hot = F.one_hot(true_labels, num_classes) + + # We call the compare_argmax function. 
+ compare_argmax(logits, chance_one_hot) + + # This function does not return anything, it only creates a plot. + # Therefore, we can only visually inspect the plot to check if it is correct. + + def test_plot_argmax_distribution(self): + """ + This test checks whether the plot_argmax_distribution function correctly + plots the distribution of argmax values. + """ + + # We initialize the input parameters. + batch_size = 256 + num_classes = 10 + true_labels = torch.randint(0, num_classes, [batch_size]) + true_chance_one_hot = F.one_hot(true_labels, num_classes) + + # We call the plot_argmax_distribution function. + plot_argmax_distribution(true_chance_one_hot) + + # This function does not return anything, it only creates a plot. + # Therefore, we can only visually inspect the plot to check if it is correct. + + +# We use the pytest.mark.unittest decorator to mark this class for unit testing. @pytest.mark.unittest class TestUtils(): + # This function tests the negative_cosine_similarity function. + # This function computes the negative cosine similarity between two vectors. def test_negative_cosine_similarity(self): + # We initialize the input parameters. batch_size = 256 dim = 512 x1 = torch.randn(batch_size, dim) x2 = torch.randn(batch_size, dim) + + # We call the negative_cosine_similarity function. output = negative_cosine_similarity(x1, x2) + + # We check if the output shape is as expected. assert output.shape == (batch_size, ) + + # We check if all elements of the output are between -1 and 1. assert ((output >= -1) & (output <= 1)).all() + # We test a special case where the two input vectors are in the same direction. + # In this case, the cosine similarity should be -1. x1 = torch.randn(batch_size, dim) positive_factor = torch.randint(1, 100, [1]) output_positive = negative_cosine_similarity(x1, positive_factor.float() * x1) assert output_positive.shape == (batch_size, ) - # assert (output_negative == -1).all() # is not True, because of the numerical precision assert ((output_positive - (-1)) < 1e-6).all() + # We test another special case where the two input vectors are in opposite directions. + # In this case, the cosine similarity should be 1. 
negative_factor = -torch.randint(1, 100, [1]) output_negative = negative_cosine_similarity(x1, negative_factor.float() * x1) assert output_negative.shape == (batch_size, ) - # assert (output_negative == 1).all() - # assert (output_negative == 1).all() # is not True, because of the numerical precision assert ((output_positive - 1) < 1e-6).all() def test_to_torch_float_tensor(self): @@ -38,14 +151,14 @@ def test_to_torch_float_tensor(self): ), np.random.randn(4, 5), np.random.randn(4, 5), np.random.randn(4, 5), np.random.randn(4, 5) data_list_np = [ mask_batch_np, - target_value_prefix_np.astype('float64'), - target_value_np.astype('float64'), target_policy_np, weights_np + target_value_prefix_np.astype('float32'), + target_value_np.astype('float32'), target_policy_np, weights_np ] [mask_batch_func, target_value_prefix_func, target_value_func, target_policy_func, weights_func] = to_torch_float_tensor(data_list_np, device) mask_batch_2 = torch.from_numpy(mask_batch_np).to(device).float() - target_value_prefix_2 = torch.from_numpy(target_value_prefix_np.astype('float64')).to(device).float() - target_value_2 = torch.from_numpy(target_value_np.astype('float64')).to(device).float() + target_value_prefix_2 = torch.from_numpy(target_value_prefix_np.astype('float32')).to(device).float() + target_value_2 = torch.from_numpy(target_value_np.astype('float32')).to(device).float() target_policy_2 = torch.from_numpy(target_policy_np).to(device).float() weights_2 = torch.from_numpy(weights_np).to(device).float() diff --git a/zoo/atari/entry/atari_eval_config.py b/zoo/atari/entry/atari_eval.py similarity index 100% rename from zoo/atari/entry/atari_eval_config.py rename to zoo/atari/entry/atari_eval.py diff --git a/zoo/board_games/gomoku/entry/gomoku_alphazero_eval_config.py b/zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py similarity index 100% rename from zoo/board_games/gomoku/entry/gomoku_alphazero_eval_config.py rename to zoo/board_games/gomoku/entry/gomoku_alphazero_eval.py diff --git a/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval_config.py b/zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py similarity index 100% rename from zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval_config.py rename to zoo/board_games/gomoku/entry/gomoku_gumbel_muzero_eval.py diff --git a/zoo/board_games/gomoku/entry/gomoku_muzero_eval_config.py b/zoo/board_games/gomoku/entry/gomoku_muzero_eval.py similarity index 100% rename from zoo/board_games/gomoku/entry/gomoku_muzero_eval_config.py rename to zoo/board_games/gomoku/entry/gomoku_muzero_eval.py diff --git a/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval_config.py b/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py similarity index 100% rename from zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval_config.py rename to zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py diff --git a/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval_config.py b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py similarity index 100% rename from zoo/board_games/tictactoe/entry/tictactoe_muzero_eval_config.py rename to zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py diff --git a/zoo/box2d/bipedalwalker/entry/bipedalwalker_eval_config.py b/zoo/box2d/bipedalwalker/entry/bipedalwalker_eval.py similarity index 100% rename from zoo/box2d/bipedalwalker/entry/bipedalwalker_eval_config.py rename to zoo/box2d/bipedalwalker/entry/bipedalwalker_eval.py diff --git a/zoo/box2d/lunarlander/entry/lunarlander_eval_config.py 
b/zoo/box2d/lunarlander/entry/lunarlander_eval.py similarity index 100% rename from zoo/box2d/lunarlander/entry/lunarlander_eval_config.py rename to zoo/box2d/lunarlander/entry/lunarlander_eval.py diff --git a/zoo/classic_control/cartpole/entry/cartpole_eval_config.py b/zoo/classic_control/cartpole/entry/cartpole_eval.py similarity index 100% rename from zoo/classic_control/cartpole/entry/cartpole_eval_config.py rename to zoo/classic_control/cartpole/entry/cartpole_eval.py diff --git a/zoo/classic_control/pendulum/entry/pendulum_eval_config.py b/zoo/classic_control/pendulum/entry/pendulum_eval.py similarity index 100% rename from zoo/classic_control/pendulum/entry/pendulum_eval_config.py rename to zoo/classic_control/pendulum/entry/pendulum_eval.py diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index fa788fc3e..45eb66271 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -26,7 +26,7 @@ stop_value=int(1e6), env_name=env_name, obs_shape=(16, 4, 4), - obs_type='dict_observation', + obs_type='dict_encoded_board', raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' reward_normalize=False, reward_norm_scale=100, diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 6ec45e1c7..d298d3031 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -27,7 +27,7 @@ stop_value=int(1e6), env_name=env_name, obs_shape=(16, 4, 4), - obs_type='dict_observation', + obs_type='dict_encoded_board', raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' reward_normalize=False, reward_scale=100, diff --git a/zoo/game_2048/entry/2048_bot_eval.py b/zoo/game_2048/entry/2048_bot_eval.py new file mode 100644 index 000000000..1ca2a042c --- /dev/null +++ b/zoo/game_2048/entry/2048_bot_eval.py @@ -0,0 +1,57 @@ +import numpy as np +from easydict import EasyDict +from rich import print + +from zoo.game_2048.envs.expectimax_search_based_bot import expectimax_search +from zoo.game_2048.envs.game_2048_env import Game2048Env + +# Define game configuration +config = EasyDict(dict( + env_name="game_2048", + save_replay=False, + replay_format='mp4', + replay_name_suffix='test', + replay_path=None, + render_real_time=False, + act_scale=True, + channel_last=True, + obs_type='raw_board', # options=['raw_board', 'raw_encoded_board', 'dict_encoded_board'] + reward_type='raw', # options=['raw', 'merged_tiles_plus_log_max_tile_num'] + reward_normalize=False, + reward_norm_scale=100, + max_tile=int(2 ** 16), + delay_reward_step=0, + prob_random_agent=0., + max_episode_steps=int(1e4), + is_collect=False, + ignore_legal_actions=True, + need_flatten=False, + num_of_possible_chance_tile=2, + possible_tiles=np.array([2, 4]), + tile_probabilities=np.array([0.9, 0.1]), +)) + +if __name__ == "__main__": + game_2048_env = Game2048Env(config) + obs = game_2048_env.reset() + print('init board state: ') + game_2048_env.render() + step = 0 + while True: + print('=' * 40) + grid = obs.astype(np.int64) + # action = game_2048_env.human_to_action() # which obtain about 10000 score + # action = game_2048_env.random_action() # which obtain about 1000 score + action = expectimax_search(grid) # which obtain about 58536 score + try: + obs, reward, done, info = game_2048_env.step(action) + except Exception as e: + print(f'Exception: {e}') + print('total_step_number: {}'.format(step)) + break 
+ step += 1 + print(f"step: {step}, action: {action}, reward: {reward}, raw_reward: {info['raw_reward']}") + game_2048_env.render(mode='human') + if done: + print('total_step_number: {}'.format(step)) + break diff --git a/zoo/game_2048/entry/stochastic_muzero_2048_eval_config.py b/zoo/game_2048/entry/stochastic_muzero_2048_eval.py similarity index 86% rename from zoo/game_2048/entry/stochastic_muzero_2048_eval_config.py rename to zoo/game_2048/entry/stochastic_muzero_2048_eval.py index f3899e3d8..df8b0956f 100644 --- a/zoo/game_2048/entry/stochastic_muzero_2048_eval_config.py +++ b/zoo/game_2048/entry/stochastic_muzero_2048_eval.py @@ -1,17 +1,17 @@ # According to the model you want to evaluate, import the corresponding config. -from lzero.entry import eval_muzero import numpy as np +from lzero.entry import eval_muzero +from zoo.game_2048.config.stochastic_muzero_2048_config import main_config, create_config + if __name__ == "__main__": """ model_path (:obj:`Optional[str]`): The pretrained model path, which should point to the ckpt file of the pretrained model, and an absolute path is recommended. In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. """ - # Take the config of sampled efficientzero as an example - from stochastic_muzero_2048_config import main_config, create_config - model_path = "/Users/puyuan/code/LightZero/data_stochastic_mz_ctree/game_2048_stochastic_muzero_ns100_upc200_rr0.0_bs512_chance-True-32_seed0/ckpt/ckpt_best.pth.tar" + model_path = "./ckpt/ckpt_best.pth.tar" returns_mean_seeds = [] returns_seeds = [] diff --git a/zoo/game_2048/entry/rule_based_2048_config.py b/zoo/game_2048/envs/expectimax_search_based_bot.py similarity index 72% rename from zoo/game_2048/entry/rule_based_2048_config.py rename to zoo/game_2048/envs/expectimax_search_based_bot.py index 7d48065ac..f9701e44e 100644 --- a/zoo/game_2048/entry/rule_based_2048_config.py +++ b/zoo/game_2048/envs/expectimax_search_based_bot.py @@ -2,17 +2,13 @@ from typing import Tuple, Union import numpy as np -from easydict import EasyDict -from rich import print -from zoo.game_2048.envs.game_2048_env import Game2048Env - -# Define rule-based search function -def rule_based_search(grid: np.array, fast_search: bool = True) -> int: +# Define expectimax search bot for 2048 env +def expectimax_search(grid: np.array, fast_search: bool = True) -> int: """ Overview: - Use Expectimax search algorithm to find the best action. + Use Expectimax search algorithm to find the best action for 2048 env. Adapted from https://github.com/xwjdsh/2048-ai/blob/master/ai/ai.go. 
""" # please refer to https://codemyroad.wordpress.com/2014/05/14/2048-ai-the-intelligent-bot/ @@ -149,56 +145,4 @@ def generate(grid: np.array) -> np.array: # set new number grid[empty[0][index], empty[1][index]] = number # return new grid - return grid - - -# Define game configuration -config = EasyDict(dict( - env_name="game_2048", - save_replay=False, - replay_format='mp4', - replay_name_suffix='test', - replay_path=None, - render_real_time=False, - act_scale=True, - channel_last=True, - obs_type='array', # options=['raw_observation', 'dict_observation', 'array'] - reward_type='raw', # options=['raw', 'merged_tiles_plus_log_max_tile_num'] - reward_normalize=False, - reward_norm_scale=100, - max_tile=int(2 ** 16), - delay_reward_step=0, - prob_random_agent=0., - max_episode_steps=int(1e4), - is_collect=False, - ignore_legal_actions=True, - need_flatten=False, - num_of_possible_chance_tile=2, - possible_tiles=np.array([2, 4]), - tile_probabilities=np.array([0.9, 0.1]), -)) - -if __name__ == "__main__": - game_2048_env = Game2048Env(config) - obs = game_2048_env.reset() - print('init board state: ') - game_2048_env.render() - step = 0 - while True: - print('=' * 40) - grid = obs.astype(np.int64) - # action = game_2048_env.human_to_action() # which obtain about 10000 score - # action = game_2048_env.random_action() # which obtain about 1000 score - action = rule_based_search(grid) # which obtain about 58536 score - try: - obs, reward, done, info = game_2048_env.step(action) - except Exception as e: - print(f'Exception: {e}') - print('total_step_number: {}'.format(step)) - break - step += 1 - print(f"step: {step}, action: {action}, reward: {reward}, raw_reward: {info['raw_reward']}") - game_2048_env.render(mode='human') - if done: - print('total_step_number: {}'.format(step)) - break + return grid \ No newline at end of file diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 9884f1683..95c45a01b 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -19,6 +19,71 @@ @ENV_REGISTRY.register('game_2048') class Game2048Env(gym.Env): + """ + Overview: + The Game2048Env is a gym environment implementation of the 2048 game. The goal of the game is to slide numbered tiles + on a grid to combine them and create a tile with the number 2048 (or larger). The environment provides an interface to interact with + the game and receive observations, rewards, and game status information. + + Interfaces: + - reset(init_board=None, add_random_tile_flag=True): + Resets the game board and starts a new episode. It returns the initial observation of the game. + - step(action): + Advances the game by one step based on the provided action. It returns the new observation, reward, game status, + and additional information. + - render(mode='human'): + Renders the current state of the game for visualization purposes. + MDP Definition: + - Observation Space: + The observation space is a 4x4 grid representing the game board. Each cell in the grid can contain a number from + 0 to 2048. The observation can be in different formats based on the 'obs_type' parameter in the environment configuration. + - If 'obs_type' is set to 'encode_observation' (default): + The observation is a 3D numpy array of shape (4, 4, 16). Each cell in the array is represented as a one-hot vector + encoding the value of the tile in that cell. The one-hot vector has a length of 16, representing the possible tile + values from 0 to 2048. 
The first element in the one-hot vector corresponds to an empty cell (0 value). + - If 'obs_type' is set to 'dict_encoded_board': + The observation is a dictionary with the following keys: + - 'observation': A 3D numpy array representing the game board as described above. + - 'action_mask': A binary mask representing the legal actions that can be taken in the current state. + - 'to_play': A placeholder value (-1) indicating the current player (not applicable in this game). + - 'chance': A placeholder value representing the chance outcome (not applicable in this game). + - If 'obs_type' is set to 'raw_board': + The observation is the raw game board as a 2D numpy array of shape (4, 4). + - Action Space: + The action space is a discrete space with 4 possible actions: + - 0: Move Up + - 1: Move Right + - 2: Move Down + - 3: Move Left + - Reward: + The reward depends on the 'reward_type' parameter in the environment configuration. + - If 'reward_type' is set to 'raw': + The reward is a floating-point number representing the immediate reward obtained from the last action. + - If 'reward_type' is set to 'merged_tiles_plus_log_max_tile_num': + The reward is a floating-point number representing the number of merged tiles in the current step. + If the maximum tile number on the board after the step is greater than the previous maximum tile number, + the reward is further adjusted by adding the logarithm of the new maximum tile number multiplied by 0.1. + The reward is calculated as follows: reward = num_of_merged_tiles + (log2(new_max_tile_num) * 0.1) + If the new maximum tile number is the same as the previous maximum tile number, the reward does not + include the second term. Note: This reward type requires 'reward_normalize' to be set to False. + - Done: + The game ends when one of the following conditions is met: + - The maximum tile number (configured by 'max_tile') is reached. + - There are no legal moves left. + - The number of steps in the episode exceeds the maximum episode steps (configured by 'max_episode_steps'). + - Additional Information: + The 'info' dictionary returned by the 'step' method contains additional information about the current state. + The following keys are included in the dictionary: + - 'raw_reward': The raw reward obtained from the last action. + - 'current_max_tile_num': The current maximum tile number on the board. + - Rendering: + The 'render' method can be used to visualize the current state of the game. It supports two rendering modes: + - 'human': Renders the game in a text-based format in the console. + - 'rgb_array_render': Renders the game as an RGB image. + Note: The rendering mode is set to 'human' by default. + """ + + # The default_config for game 2048 env. 
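To make the 'merged_tiles_plus_log_max_tile_num' shaping described above concrete, here is a small illustrative sketch; the helper name and signature are editorial, not part of the env:

import math

def shaped_reward(num_merged_tiles: int, prev_max_tile: int, new_max_tile: int) -> float:
    # Hypothetical helper mirroring the reward rule in the docstring above:
    # the merged-tile count, plus a log bonus only when a new maximum tile is reached.
    reward = float(num_merged_tiles)
    if new_max_tile > prev_max_tile:
        reward += math.log2(new_max_tile) * 0.1
    return reward

assert shaped_reward(2, 64, 64) == 2.0               # no new max tile: just the merge count
assert math.isclose(shaped_reward(2, 64, 128), 2.7)  # new max 128: 2 + log2(128) * 0.1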
config = dict( env_name="game_2048", save_replay=False, @@ -28,7 +93,7 @@ class Game2048Env(gym.Env): render_real_time=False, act_scale=True, channel_last=True, - obs_type='raw_observation', # options=['raw_observation', 'dict_observation', 'array'] + obs_type='dict_encoded_board', # options=['raw_board', 'raw_encoded_board', 'dict_encoded_board'] reward_normalize=False, reward_norm_scale=100, reward_type='raw', # options=['raw', 'merged_tiles_plus_log_max_tile_num'] @@ -43,7 +108,6 @@ class Game2048Env(gym.Env): possible_tiles=np.array([2, 4]), tile_probabilities=np.array([0.9, 0.1]), ) - metadata = {'render.modes': ['human', 'rgb_array_render']} @classmethod def default_config(cls: type) -> EasyDict: @@ -69,7 +133,7 @@ def __init__(self, cfg: dict) -> None: self.reward_norm_scale = cfg.reward_norm_scale assert self.reward_type in ['raw', 'merged_tiles_plus_log_max_tile_num'] assert self.reward_type == 'raw' or ( - self.reward_type == 'merged_tiles_plus_log_max_tile_num' and self.reward_normalize == False) + self.reward_type == 'merged_tiles_plus_log_max_tile_num' and self.reward_normalize is False) self.max_tile = cfg.max_tile # Define the maximum tile that will end the game (e.g. 2048). None means no limit. # This does not affect the state returned. @@ -143,15 +207,19 @@ def reset(self, init_board=None, add_random_tile_flag=True): observation = observation.reshape(-1) # Based on the observation type, create the appropriate observation object - if self.obs_type == 'dict_observation': + if self.obs_type == 'dict_encoded_board': observation = { 'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance } - elif self.obs_type == 'array': + elif self.obs_type == 'raw_board': observation = self.board + elif self.obs_type == 'raw_encoded_board': + observation = observation + else: + raise NotImplementedError # Render the game if the replay is to be saved if self.save_replay: @@ -165,7 +233,7 @@ def step(self, action): Perform one step of the game. This involves making a move, adding a new tile, and updating the game state. New tile could be added randomly or from the tile probabilities. The rewards are calculated based on the game configuration ('merged_tiles_plus_log_max_tile_num' or 'raw'). - The observations are also returned based on the game configuration ('dict_observation', 'array', or 'raw'). + The observations are also returned based on the game configuration ('raw_board', 'raw_encoded_board' or 'dict_encoded_board'). Arguments: - action (:obj:`int`): The action to be performed. 
Returns: @@ -230,12 +298,14 @@ def step(self, action): action_mask[self.legal_actions] = 1 # Return the observation based on the observation type - if self.obs_type == 'dict_observation': + if self.obs_type == 'dict_encoded_board': observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'chance': self.chance} - elif self.obs_type == 'array': + elif self.obs_type == 'raw_board': observation = self.board - else: + elif self.obs_type == 'raw_encoded_board': observation = observation + else: + raise NotImplementedError # Normalize the reward if necessary if self.reward_normalize: From 1258be5e3e9b460dfdc4438bf3a86d880c8c6211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 11 Sep 2023 00:23:04 +0800 Subject: [PATCH 26/28] fix(pu): fix test_game_segment.py --- lzero/mcts/tests/atari_efficientzero_config_test.py | 3 ++- lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py | 3 ++- lzero/policy/stochastic_muzero.py | 6 +----- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/lzero/mcts/tests/atari_efficientzero_config_test.py b/lzero/mcts/tests/atari_efficientzero_config_test.py index 5d3afd7ad..1da966cc9 100644 --- a/lzero/mcts/tests/atari_efficientzero_config_test.py +++ b/lzero/mcts/tests/atari_efficientzero_config_test.py @@ -63,6 +63,8 @@ ), cuda=True, env_type='not_board_games', + transform2string=False, + gray_scale=False, game_segment_length=400, use_augmentation=True, num_simulations=num_simulations, @@ -80,7 +82,6 @@ collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, discount_factor=0.997, - transform2string=False, lstm_horizon_len=5, use_ture_chance_label_in_chance_encoder=False, ), diff --git a/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py b/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py index 6369a2ea3..eb8eef258 100644 --- a/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py +++ b/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py @@ -52,6 +52,8 @@ ), cuda=True, env_type='board_games', + transform2string=False, + gray_scale=False, update_per_collect=update_per_collect, batch_size=batch_size, optim_type='Adam', @@ -71,7 +73,6 @@ replay_buffer_size=int(3e3), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, - transform2string=False, lstm_horizon_len=5, use_ture_chance_label_in_chance_encoder=False, ), diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index efaecb091..c0182632c 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -466,17 +466,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # calculate the topK accuracy of afterstate_policy_logits and plot the topK accuracy curve. plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, topK_values) - # TODO(pu): whether to detach the chance_one_hot in the commitment loss. - # commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float().detach()) + # TODO(pu): whether to detach the chance_encoding in the commitment loss. 
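To spell out the trade-off in the TODO above, the two candidate formulations look roughly as follows; this is a standalone sketch with stand-in tensors, not the adopted choice:

import torch

mse = torch.nn.MSELoss()
# Stand-ins for the real tensors: in the model, chance_one_hot comes from the
# straight-through one-hot argmax of the encoder output, so it can carry gradients.
chance_encoding = torch.randn(4, 32, requires_grad=True)
chance_one_hot = torch.zeros(4, 32).scatter_(1, torch.randint(0, 32, (4, 1)), 1.0)

# (a) stop-gradient on the one-hot target: only chance_encoding is pulled toward the code.
loss_a = mse(chance_encoding, chance_one_hot.detach())
# (b) no stop-gradient: in the model, gradients could also flow through the one-hot branch.
loss_b = mse(chance_encoding, chance_one_hot)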
commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float()) afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_i]) value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) - # Follow MuZero, set half gradient - # latent_state.register_hook(lambda grad: grad * 0.5) - if self._cfg.monitor_extra_statistics: original_rewards = self.inverse_scalar_transform_handle(reward) original_rewards_cpu = original_rewards.detach().cpu() From 9e5b3d8e4827fab724750eb2338e6e73f0d791a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 12 Sep 2023 15:17:32 +0800 Subject: [PATCH 27/28] polish(pu): polish comments, abstract a get_target_obs_index_in_step_k method and add its unittest --- lzero/mcts/tests/__init__.py | 0 lzero/mcts/tests/config/__init__.py | 0 .../atari_efficientzero_config_for_test.py} | 0 ...tactoe_muzero_bot_mode_config_for_test.py} | 0 lzero/mcts/tests/test_game_segment.py | 4 +- lzero/model/stochastic_muzero_model.py | 204 ++++++------------ lzero/model/stochastic_muzero_model_mlp.py | 83 ++----- lzero/policy/alphazero.py | 8 +- lzero/policy/efficientzero.py | 17 +- lzero/policy/gumbel_muzero.py | 22 +- lzero/policy/muzero.py | 81 ++++--- lzero/policy/random_policy.py | 9 +- lzero/policy/sampled_efficientzero.py | 20 +- lzero/policy/stochastic_muzero.py | 108 ++++------ lzero/policy/tests/config/__init__.py | 0 .../config/atari_muzero_config_for_test.py | 98 +++++++++ .../config/cartpole_muzero_config_for_test.py | 74 +++++++ .../test_get_target_obs_index_in_step_k.py | 73 +++++++ .../cartpole_stochastic_muzero_config.py | 3 +- .../config/stochastic_muzero_2048_config.py | 9 +- zoo/game_2048/entry/2048_bot_eval.py | 13 +- ...astic_muzero_2048_eval.py => 2048_eval.py} | 7 +- zoo/game_2048/envs/game_2048_env.py | 34 ++- 23 files changed, 491 insertions(+), 376 deletions(-) create mode 100644 lzero/mcts/tests/__init__.py create mode 100644 lzero/mcts/tests/config/__init__.py rename lzero/mcts/tests/{atari_efficientzero_config_test.py => config/atari_efficientzero_config_for_test.py} (100%) rename lzero/mcts/tests/{tictactoe_muzero_bot_mode_config_test.py => config/tictactoe_muzero_bot_mode_config_for_test.py} (100%) create mode 100644 lzero/policy/tests/config/__init__.py create mode 100644 lzero/policy/tests/config/atari_muzero_config_for_test.py create mode 100644 lzero/policy/tests/config/cartpole_muzero_config_for_test.py create mode 100644 lzero/policy/tests/test_get_target_obs_index_in_step_k.py rename zoo/game_2048/entry/{stochastic_muzero_2048_eval.py => 2048_eval.py} (89%) diff --git a/lzero/mcts/tests/__init__.py b/lzero/mcts/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lzero/mcts/tests/config/__init__.py b/lzero/mcts/tests/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lzero/mcts/tests/atari_efficientzero_config_test.py b/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py similarity index 100% rename from lzero/mcts/tests/atari_efficientzero_config_test.py rename to lzero/mcts/tests/config/atari_efficientzero_config_for_test.py diff --git a/lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py similarity index 100% rename from lzero/mcts/tests/tictactoe_muzero_bot_mode_config_test.py rename to 
lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py diff --git a/lzero/mcts/tests/test_game_segment.py b/lzero/mcts/tests/test_game_segment.py index 26d03e421..492e00c69 100644 --- a/lzero/mcts/tests/test_game_segment.py +++ b/lzero/mcts/tests/test_game_segment.py @@ -17,14 +17,14 @@ def test_game_segment(test_algo): if test_algo == 'EfficientZero': from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree from lzero.model.efficientzero_model import EfficientZeroModel as Model - from lzero.mcts.tests.atari_efficientzero_config_test import atari_efficientzero_config as config + from lzero.mcts.tests.config.atari_efficientzero_config_for_test import atari_efficientzero_config as config from zoo.atari.envs.atari_lightzero_env import AtariLightZeroEnv envs = [AtariLightZeroEnv(config.env) for _ in range(config.env.evaluator_env_num)] elif test_algo == 'MuZero': from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree from lzero.model.muzero_model import MuZeroModel as Model - from lzero.mcts.tests.tictactoe_muzero_bot_mode_config_test import tictactoe_muzero_config as config + from lzero.mcts.tests.config.tictactoe_muzero_bot_mode_config_for_test import tictactoe_muzero_config as config from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv envs = [TicTacToeEnv(config.env) for _ in range(config.env.evaluator_env_num)] diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index b6e6b7765..ac0614bf8 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -579,103 +579,8 @@ def get_reward_mean(self) -> float: return get_reward_mean(self) -class AfterstateDynamicsNetwork(nn.Module): - - def __init__( - self, - num_res_blocks: int, - num_channels: int, - reward_head_channels: int, - fc_reward_layers: SequenceType, - output_support_size: int, - flatten_output_size_for_reward_head: int, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - ): - """ - Overview: - The definition of afterstate dynamics network in Stochastic MuZero algorithm, which is used to predict next afterstate given current latent state and action. - Arguments: - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of input, including obs and action encoding. - - reward_head_channels (:obj:`int`): The channels of reward head. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - output_support_size (:obj:`int`): The size of categorical reward output. - - flatten_output_size_for_reward_head (:obj:`int`): The flatten size of output for reward head, i.e., \ - the input size of reward head. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initialization for the last layer of \ - reward mlp, default set it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). 
- """ - super().__init__() - self.num_channels = num_channels - self.flatten_output_size_for_reward_head = flatten_output_size_for_reward_head - self.conv = nn.Conv2d(num_channels, num_channels - 1, kernel_size=3, stride=1, padding=1, bias=False) - self.bn = nn.BatchNorm2d(num_channels - 1) - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels - 1, activation=activation, norm_type='BN', res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - self.conv1x1_reward = nn.Conv2d(num_channels - 1, reward_head_channels, 1) - self.bn_reward = nn.BatchNorm2d(reward_head_channels) - self.fc_reward_head = MLP( - self.flatten_output_size_for_reward_head, - hidden_channels=fc_reward_layers[0], - layer_num=len(fc_reward_layers) + 1, - out_channels=output_support_size, - activation=activation, - norm_type='BN', - output_activation=False, - output_norm=False, - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.activation = activation - - def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Forward computation of the dynamics network. Predict next latent state given current latent state and action. - Arguments: - - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ - latent state and action encoding, with shape (batch_size, num_channels, height, width). - Returns: - - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, num_channels, \ - height, width). - - reward (:obj:`torch.Tensor`): The predicted reward, with shape (batch_size, output_support_size). - """ - # take the state encoding (afterstate), state_action_encoding[:, -1, :, :] is action encoding - afterstate = state_action_encoding[:, :-1, :, :] - x = self.conv(state_action_encoding) - x = self.bn(x) - - # the residual link: add state encoding to the state_action encoding - x += afterstate - x = self.activation(x) - - for block in self.resblocks: - x = block(x) - afterstate = x - # reward = None - - x = self.conv1x1_reward(afterstate) - x = self.bn_reward(x) - x = self.activation(x) - x = x.view(-1, self.flatten_output_size_for_reward_head) - - # use the fully connected layer to predict reward - reward = self.fc_reward_head(x) - - return afterstate, reward - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> float: - return get_reward_mean(self) +# TODO(pu): customize different afterstate dynamics network +AfterstateDynamicsNetwork = DynamicsNetwork class AfterstatePredictionNetwork(nn.Module): @@ -758,7 +663,7 @@ def __init__( def forward(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: - Forward computation of the prediction network. + Forward computation of the afterstate prediction network. Arguments: - afterstate (:obj:`torch.Tensor`): input tensor with shape (B, afterstate_dim). Returns: @@ -785,11 +690,20 @@ def forward(self, afterstate: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] class ChanceEncoderBackbone(nn.Module): - def __init__(self, observation_space_dimensions, chance_encoding_dim=4): + """ + Overview: + The definition of chance encoder backbone network, \ + which is used to encode the (image) observation into a latent space. + Arguments: + - input_dimensions (:obj:`tuple`): The dimension of observation space. + - chance_encoding_dim (:obj:`int`): The dimension of chance encoding. 
+ """ + + def __init__(self, input_dimensions, chance_encoding_dim=4): super(ChanceEncoderBackbone, self).__init__() - self.conv1 = nn.Conv2d(observation_space_dimensions[0] * 2, 32, 3, padding=1) + self.conv1 = nn.Conv2d(input_dimensions[0] * 2, 32, 3, padding=1) self.conv2 = nn.Conv2d(32, 64, 3, padding=1) - self.fc1 = nn.Linear(64 * observation_space_dimensions[1] * observation_space_dimensions[2], 128) + self.fc1 = nn.Linear(64 * input_dimensions[1] * input_dimensions[2], 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, chance_encoding_dim) @@ -804,6 +718,15 @@ def forward(self, x): class ChanceEncoderBackboneMLP(nn.Module): + """ + Overview: + The definition of chance encoder backbone network, \ + which is used to encode the (vector) observation into a latent space. + Arguments: + - input_dimensions (:obj:`tuple`): The dimension of observation space. + - chance_encoding_dim (:obj:`int`): The dimension of chance encoding. + """ + def __init__(self, input_dimensions, chance_encoding_dim=4): super(ChanceEncoderBackboneMLP, self).__init__() self.fc1 = nn.Linear(input_dimensions, 128) @@ -818,15 +741,16 @@ def forward(self, x): class ChanceEncoder(nn.Module): + def __init__(self, input_dimensions, action_dimension, encoder_backbone_type='conv'): super().__init__() # Specify the action space for the model self.action_space = action_dimension if encoder_backbone_type == 'conv': # Define the encoder, which transforms observations into a latent space - self.chance_encoder = ChanceEncoderBackbone(input_dimensions, action_dimension) + self.encoder = ChanceEncoderBackbone(input_dimensions, action_dimension) elif encoder_backbone_type == 'mlp': - self.chance_encoder = ChanceEncoderBackboneMLP(input_dimensions, action_dimension) + self.encoder = ChanceEncoderBackboneMLP(input_dimensions, action_dimension) else: raise ValueError('Encoder backbone type not supported') @@ -835,23 +759,21 @@ def __init__(self, input_dimensions, action_dimension, encoder_backbone_type='co def forward(self, observations): """ - Forward method for the ChanceEncoder. This method takes an observation - and applies the encoder to transform it to a latent space. Then applies the - StraightThroughEstimator to this encoding. - - References: - Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5, - Chance Outcomes section. + Overview: + Forward method for the ChanceEncoder. This method takes an observation \ + and applies the encoder to transform it to a latent space. Then applies the \ + StraightThroughEstimator to this encoding. \ + References: Planning in Stochastic Environments with a Learned Model (ICLR 2022), page 5, + Chance Outcomes section. Arguments: - observations (Tensor): Observation tensor. - + - observations (:obj:`torch.Tensor`): Observation tensor. Returns: - chance (Tensor): Transformed tensor after applying one-hot argmax. - chance_encoding (Tensor): Encoding of the input observation tensor. + - chance (:obj:`torch.Tensor`): Transformed tensor after applying one-hot argmax. + - chance_encoding (:obj:`torch.Tensor`): Encoding of the input observation tensor. """ # Apply the encoder to the observation - chance_encoding = self.chance_encoder(observations) + chance_encoding = self.encoder(observations) # Apply one-hot argmax to the encoding chance_onehot = self.onehot_argmax(chance_encoding) return chance_encoding, chance_onehot @@ -863,14 +785,13 @@ def __init__(self): def forward(self, x): """ - Forward method for the StraightThroughEstimator. 
This applies the one-hot argmax - function to the input tensor. - + Overview: + Forward method for the StraightThroughEstimator. This applies the one-hot argmax \ + function to the input tensor. Arguments: - x (Tensor): Input tensor. - + - x (:obj:`torch.Tensor`): Input tensor. Returns: - Tensor: Transformed tensor after applying one-hot argmax. + - (:obj:`torch.Tensor`): Transformed tensor after applying one-hot argmax. """ # Apply one-hot argmax to the input x = OnehotArgmax.apply(x) @@ -879,28 +800,28 @@ def forward(self, x): class OnehotArgmax(torch.autograd.Function): """ - Custom PyTorch function for one-hot argmax. This function transforms the input tensor - into a one-hot tensor where the index with the maximum value in the original tensor is - set to 1 and all other indices are set to 0. It allows gradients to flow to the encoder - during backpropagation. - - For more information, refer to: - https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html + Overview: + Custom PyTorch function for one-hot argmax. This function transforms the input tensor \ + into a one-hot tensor where the index with the maximum value in the original tensor is \ + set to 1 and all other indices are set to 0. It allows gradients to flow to the encoder \ + during backpropagation. + + For more information, refer to: \ + https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html """ @staticmethod def forward(ctx, input): """ - Forward method for the one-hot argmax function. This method transforms the input - tensor into a one-hot tensor. - + Overview: + Forward method for the one-hot argmax function. This method transforms the input \ + tensor into a one-hot tensor. Arguments: - ctx (context): A context object that can be used to stash information for + - ctx (:obj:`context`): A context object that can be used to stash information for backward computation. - input (Tensor): Input tensor. - + - input (:obj:`torch.Tensor`): Input tensor. Returns: - Tensor: One-hot tensor. + - (:obj:`torch.Tensor`): One-hot tensor. """ # Transform the input tensor to a one-hot tensor return torch.zeros_like(input).scatter_(-1, torch.argmax(input, dim=-1, keepdim=True), 1.) @@ -908,14 +829,13 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_output): """ - Backward method for the one-hot argmax function. This method allows gradients - to flow to the encoder during backpropagation. - + Overview: + Backward method for the one-hot argmax function. This method allows gradients \ + to flow to the encoder during backpropagation. Arguments: - ctx (context): A context object that was stashed in the forward pass. - grad_output (Tensor): The gradient of the output tensor. - + - ctx (:obj:`context`): A context object that was stashed in the forward pass. + - grad_output (:obj:`torch.Tensor`): The gradient of the output tensor. Returns: - Tensor: The gradient of the input tensor. + - (:obj:`torch.Tensor`): The gradient of the input tensor. """ return grad_output diff --git a/lzero/model/stochastic_muzero_model_mlp.py b/lzero/model/stochastic_muzero_model_mlp.py index f9575820e..7c958e2ea 100644 --- a/lzero/model/stochastic_muzero_model_mlp.py +++ b/lzero/model/stochastic_muzero_model_mlp.py @@ -41,10 +41,10 @@ def __init__( ): """ Overview: - The definition of the network model of Stochastic, which is a generalization version for 1D vector obs. - The networks are mainly built on fully connected layers. 
- The representation network is an MLP network which maps the raw observation to a latent state. - The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. + The definition of the network model of Stochastic, which is a generalization version for 1D vector obs. \ + The networks are mainly built on fully connected layers. \ + The representation network is an MLP network which maps the raw observation to a latent state. \ + The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. \ The prediction network is an MLP network which predicts the value and policy given the current latent state. Arguments: - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. @@ -110,8 +110,8 @@ def __init__( observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type ) # TODO(pu): different input data type for chance_encoder - self.chance_encoder = ChanceEncoder(observation_shape * 2, chance_space_size, - encoder_backbone_type='mlp') # input is two concatenated frames + # here, the input is two concatenated frames + self.chance_encoder = ChanceEncoder(observation_shape * 2, chance_space_size, encoder_backbone_type='mlp') self.dynamics_network = DynamicsNetwork( action_encoding_dim=self.action_encoding_dim, num_channels=self.latent_state_dim + self.action_encoding_dim, @@ -172,7 +172,7 @@ def __init__( def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: - Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` \ ``reward`` and ``next_reward_hidden_state``. Arguments: - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. @@ -229,7 +229,7 @@ def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) torch.Tensor, torch.Tensor]: """ Overview: - Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` \ ``reward`` and ``next_reward_hidden_state``. Arguments: - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. @@ -285,8 +285,8 @@ def _afterstate_dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: """ Overview: - Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. - For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Project the latent state to a lower dimension to calculate the self-supervised loss, which is \ + proposed in EfficientZero. For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. Arguments: - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. 
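The straight-through one-hot argmax used by the ChanceEncoder above (hard one-hot in the forward pass, identity gradient in the backward pass) can be sanity-checked in isolation; this sketch re-implements the idea rather than importing the patched class:

import torch

class OnehotArgmaxSketch(torch.autograd.Function):
    # Mirrors OnehotArgmax in this patch: hard one-hot in forward, identity gradient in backward.
    @staticmethod
    def forward(ctx, x):
        return torch.zeros_like(x).scatter_(-1, torch.argmax(x, dim=-1, keepdim=True), 1.)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

logits = torch.randn(2, 4, requires_grad=True)
one_hot = OnehotArgmaxSketch.apply(logits)
assert one_hot.sum() == 2                                  # exactly one 1 per row
one_hot.sum().backward()
assert torch.equal(logits.grad, torch.ones_like(logits))   # gradient passes straight through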
@@ -310,54 +310,7 @@ def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: return proj.detach() -class AfterstateDynamicsNetwork(DynamicsNetwork): - - def __init__( - self, - action_encoding_dim: int = 2, - num_channels: int = 64, - common_layer_num: int = 2, - fc_reward_layers: SequenceType = [32], - output_support_size: int = 601, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - res_connection_in_dynamics: bool = False, - ): - """ - Overview: - The definition of dynamics network in Stochastic algorithm, which is used to predict next latent state - reward by the given current latent state and action. - The networks are mainly built on fully connected layers. - Arguments: - - action_encoding_dim (:obj:`int`): The dimension of action encoding. - - num_channels (:obj:`int`): The num of channels in latent states. - - common_layer_num (:obj:`int`): The number of common layers in dynamics network. - - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). - - output_support_size (:obj:`int`): The size of categorical reward output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. - """ - super(AfterstateDynamicsNetwork, self).__init__(action_encoding_dim, num_channels, common_layer_num, - fc_reward_layers, output_support_size, - last_linear_layer_init_zero - , activation, norm_type, res_connection_in_dynamics) - - def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Overview: - Forward computation of the dynamics network. Predict the next latent state given current latent state and action. - Arguments: - - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ - latent state and action encoding, with shape (batch_size, num_channels, height, width). - Returns: - - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). - - reward (:obj:`torch.Tensor`): The predicted reward for input state. - """ - return super(AfterstateDynamicsNetwork, self).forward(state_action_encoding) +AfterstateDynamicsNetwork = DynamicsNetwork class AfterstatePredictionNetworkMLP(PredictionNetworkMLP): @@ -376,7 +329,7 @@ def __init__( ): """ Overview: - The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), + The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), \ which is used to predict value and policy by the given latent state. Arguments: - chance_space_size: (:obj:`int`): Chance space size, usually an integer number. For discrete action \ @@ -395,15 +348,3 @@ def __init__( fc_value_layers, fc_policy_layers, output_support_size, last_linear_layer_init_zero , activation, norm_type) - - def forward(self, latent_state: torch.Tensor): - """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). 
- Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). - """ - return super(AfterstatePredictionNetworkMLP, self).forward(latent_state) diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py index 28242ca76..ec8ff6a02 100644 --- a/lzero/policy/alphazero.py +++ b/lzero/policy/alphazero.py @@ -202,7 +202,7 @@ def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]: 'value_loss': value_loss.item(), 'entropy_loss': entropy_loss.item(), 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), - 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_mcts_temperature': self._collect_mcts_temperature, } def _init_collect(self) -> None: @@ -212,7 +212,7 @@ def _init_collect(self) -> None: """ self._collect_mcts = MCTS(self._cfg.mcts) self._collect_model = self._model - self.collect_mcts_temperature = 1 + self._collect_mcts_temperature = 1 @torch.no_grad() def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]: @@ -228,7 +228,7 @@ def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dic - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is the \ the corresponding policy output in this timestep, including action, probs and so on. """ - self.collect_mcts_temperature = temperature + self._collect_mcts_temperature = temperature ready_env_id = list(envs.keys()) init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id} start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id} @@ -244,7 +244,7 @@ def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dic action, mcts_probs = self._collect_mcts.get_next_action( envs[env_id], policy_forward_fn=self._policy_value_fn, - temperature=self.collect_mcts_temperature, + temperature=self._collect_mcts_temperature, sample=True ) output[env_id] = { diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 6dcc97636..4a4da03c5 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -493,7 +493,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) return { - 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_mcts_temperature': self._collect_mcts_temperature, 'collect_epsilon': self.collect_epsilon, 'cur_lr': self._optimizer.param_groups[0]['lr'], 'weighted_total_loss': weighted_total_loss.item(), @@ -529,7 +529,7 @@ def _init_collect(self) -> None: self._mcts_collect = MCTSCtree(self._cfg) else: self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 + self._collect_mcts_temperature = 1 self.collect_epsilon = 0.0 def _forward_collect( @@ -565,7 +565,7 @@ def _forward_collect( ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
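As background for the collect-phase code that follows, root visit counts are typically converted into a sampling distribution with a temperature before an action is drawn; the snippet below is a schematic of that idea and an assumption about select_action's semantics, not its exact implementation:

import numpy as np

def sample_action_from_visit_counts(visit_counts, temperature=1.0, deterministic=False):
    # Higher temperature -> closer to uniform; deterministic -> greedy w.r.t. visit counts.
    counts = np.asarray(visit_counts, dtype=np.float64)
    if deterministic:
        return int(np.argmax(counts))
    probs = counts ** (1.0 / temperature)
    probs /= probs.sum()
    return int(np.random.choice(len(counts), p=probs))

# e.g. sample_action_from_visit_counts([10, 50, 5, 35], temperature=1.0)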
""" self._collect_model.eval() - self.collect_mcts_temperature = temperature + self._collect_mcts_temperature = temperature self.collect_epsilon = epsilon active_collect_env_num = data.shape[0] with torch.no_grad(): @@ -600,8 +600,7 @@ def _forward_collect( roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play ) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] @@ -614,7 +613,7 @@ def _forward_collect( if self._cfg.eps.eps_greedy_exploration_in_collect: # eps-greedy collect action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=True + distributions, temperature=self._collect_mcts_temperature, deterministic=True ) action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] if np.random.rand() < self.collect_epsilon: @@ -624,7 +623,7 @@ def _forward_collect( # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False + distributions, temperature=self._collect_mcts_temperature, deterministic=False ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] @@ -701,8 +700,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_eval_env_num)] output = {i: None for i in data_id} diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index c0520efe4..40609c23b 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -196,7 +196,7 @@ def _init_learn(self) -> None: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type - # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. 
if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( self._model.parameters(), @@ -335,8 +335,6 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - gradient_scale = 1 / self._cfg.num_unroll_steps - # ============================================================== # the core recurrent_inference in Gumbel MuZero policy. # ============================================================== @@ -435,7 +433,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) return { - 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_mcts_temperature': self._collect_mcts_temperature, 'cur_lr': self._optimizer.param_groups[0]['lr'], 'weighted_total_loss': weighted_total_loss.item(), 'total_loss': loss.mean().item(), @@ -469,7 +467,7 @@ def _init_collect(self) -> None: self._mcts_collect = MCTSCtree(self._cfg) else: self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 + self._collect_mcts_temperature = 1 def _forward_collect( self, @@ -504,7 +502,7 @@ def _forward_collect( ``pred_value``, ``policy_logits``. """ self._collect_model.eval() - self.collect_mcts_temperature = temperature + self._collect_mcts_temperature = temperature active_collect_env_num = data.shape[0] with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} @@ -531,8 +529,8 @@ def _forward_collect( roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, list(pred_values), policy_logits, to_play) self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} roots_completed_values = roots.get_children_values(self._cfg.discount_factor, self._cfg.model.action_space_size) @@ -555,7 +553,7 @@ def _forward_collect( # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False + distributions, temperature=self._collect_mcts_temperature, deterministic=False ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the # entire action set. @@ -577,7 +575,7 @@ def _forward_collect( def _init_eval(self) -> None: """ Overview: - Evaluate mode init method. Called by ``self.__init__``. Ininitialize the eval model and MCTS utils. + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. 
""" self._eval_model = self._model if self._cfg.mcts_ctree: @@ -630,8 +628,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 roots.prepare_no_noise(reward_roots, list(pred_values), policy_logits, to_play) self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} # ============================================================== diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 5aee5b82b..97ffa3ccd 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -14,7 +14,8 @@ from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs @POLICY_REGISTRY.register('muzero') @@ -214,7 +215,7 @@ def _init_learn(self) -> None: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ assert self._cfg.optim_type in ['SGD', 'Adam'], self._cfg.optim_type - # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( self._model.parameters(), @@ -342,16 +343,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - gradient_scale = 1 / self._cfg.num_unroll_steps - # ============================================================== # the core recurrent_inference in MuZero policy. # ============================================================== - for step_i in range(self._cfg.num_unroll_steps): + for step_k in range(self._cfg.num_unroll_steps): # unroll with the dynamics function: predict the next ``latent_state``, ``reward``, # given current ``latent_state`` and ``action``. # And then predict policy_logits and value with the prediction function. - network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_i]) + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k]) latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) # transform the scaled value or its categorical representation to its original value, @@ -363,17 +362,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. # ============================================================== if self._cfg.ssl_loss_weight > 0: - # obtain the oracle hidden states from representation function. 
- if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step_i - end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference( - obs_target_batch[:, beg_index:end_index, :, :] - ) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step_i - end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + # obtain the oracle latent states from representation function. + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) @@ -381,7 +372,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # NOTE: no grad for the representation_state branch dynamic_proj = self._learn_model.project(latent_state, with_grad=True) observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in @@ -390,10 +381,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # calculate policy loss for the next ``num_unroll_steps`` unroll steps. # NOTE: the +=. # ============================================================== - policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) - value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) - reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) # Follow MuZero, set half gradient # latent_state.register_hook(lambda grad: grad * 0.5) @@ -414,8 +405,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== # weighted loss with masks (some invalid states which are out of trajectory.) 
loss = ( - self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + - self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss ) weighted_total_loss = (weights * loss).mean() @@ -442,7 +433,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in predicted_rewards = predicted_rewards.reshape(-1).unsqueeze(-1) return { - 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_mcts_temperature': self._collect_mcts_temperature, 'collect_epsilon': self.collect_epsilon, 'cur_lr': self._optimizer.param_groups[0]['lr'], 'weighted_total_loss': weighted_total_loss.item(), @@ -466,6 +457,30 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() } + def _get_target_obs_index_in_step_k(self, step): + """ + Overview: + Get the begin index and end index of the target obs in step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The begin index of the target obs in step k. + - end_index (:obj:`int`): The end index of the target obs in step k. + Examples: + >>> self._cfg.model.model_type = 'conv' + >>> self._cfg.model.image_channel = 3 + >>> self._cfg.model.frame_stack_num = 4 + >>> self._get_target_obs_index_in_step_k(0) + >>> (0, 12) + """ + if self._cfg.model.model_type == 'conv': + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type == 'mlp': + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + def _init_collect(self) -> None: """ Overview: @@ -476,7 +491,7 @@ def _init_collect(self) -> None: self._mcts_collect = MCTSCtree(self._cfg) else: self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 + self._collect_mcts_temperature = 1 self.collect_epsilon = 0.0 def _forward_collect( @@ -486,7 +501,7 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id = None + ready_env_id=None ) -> Dict: """ Overview: @@ -512,7 +527,7 @@ def _forward_collect( ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
""" self._collect_model.eval() - self.collect_mcts_temperature = temperature + self._collect_mcts_temperature = temperature self.collect_epsilon = epsilon active_collect_env_num = data.shape[0] with torch.no_grad(): @@ -540,8 +555,8 @@ def _forward_collect( roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] @@ -555,7 +570,7 @@ def _forward_collect( if self._cfg.eps.eps_greedy_exploration_in_collect: # eps greedy collect action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=True + distributions, temperature=self._collect_mcts_temperature, deterministic=True ) action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] if np.random.rand() < self.collect_epsilon: @@ -565,7 +580,7 @@ def _forward_collect( # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False + distributions, temperature=self._collect_mcts_temperature, deterministic=False ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] @@ -636,8 +651,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 roots.prepare_no_noise(reward_roots, policy_logits, to_play) self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_eval_env_num)] diff --git a/lzero/policy/random_policy.py b/lzero/policy/random_policy.py index 376ecc5b3..5a2c62c53 100644 --- a/lzero/policy/random_policy.py +++ b/lzero/policy/random_policy.py @@ -70,7 +70,7 @@ def _init_collect(self) -> None: self._mcts_collect = self.MCTSCtree(self._cfg) else: self._mcts_collect = self.MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 + self._collect_mcts_temperature = 1 self.collect_epsilon = 0.0 self.inverse_scalar_transform_handle = InverseScalarTransform( self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution @@ -109,7 +109,7 @@ def _forward_collect( ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. """ self._collect_model.eval() - self.collect_mcts_temperature = temperature + self._collect_mcts_temperature = temperature active_collect_env_num = data.shape[0] with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} @@ -155,8 +155,7 @@ def _forward_collect( else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] @@ -169,7 +168,7 @@ def _forward_collect( # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False + distributions, temperature=self._collect_mcts_temperature, deterministic=False ) # ****** sample a random action from the legal action set ******** diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 9691983cb..03e10468a 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -400,8 +400,6 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - gradient_scale = 1 / self._cfg.num_unroll_steps - # ============================================================== # the core recurrent_inference in SampledEfficientZero policy. # ============================================================== @@ -506,6 +504,8 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: self._cfg.policy_entropy_loss_weight * policy_entropy_loss ) weighted_total_loss = (weights * loss).mean() + + gradient_scale = 1 / self._cfg.num_unroll_steps weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) self._optimizer.zero_grad() weighted_total_loss.backward() @@ -529,7 +529,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: return_data = { 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_mcts_temperature': self._collect_mcts_temperature, 'weighted_total_loss': weighted_total_loss.item(), 'total_loss': loss.mean().item(), 'policy_loss': policy_loss.mean().item(), @@ -788,7 +788,7 @@ def _init_collect(self) -> None: self._mcts_collect = MCTSCtree(self._cfg) else: self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 + self._collect_mcts_temperature = 1 def _forward_collect( self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, epsilon: float = 0.25, ready_env_id=None @@ -817,7 +817,7 @@ def _forward_collect( ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. """ self._collect_model.eval() - self.collect_mcts_temperature = temperature + self._collect_mcts_temperature = temperature active_collect_env_num = data.shape[0] with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} @@ -869,8 +869,8 @@ def _forward_collect( roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play ) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} roots_sampled_actions = roots.get_sampled_actions() # {list: 1}->{list:6} @@ -889,7 +889,7 @@ def _forward_collect( # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False + distributions, temperature=self._collect_mcts_temperature, deterministic=False ) try: action = roots_sampled_actions[i][action].value @@ -995,8 +995,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} # ============================================================== # sampled related core code diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index c0182632c..d0cf2ed7d 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -5,7 +5,6 @@ import torch import torch.optim as optim from ding.model import model_wrap -from ding.policy.base_policy import Policy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from torch.nn import L1Loss @@ -14,13 +13,14 @@ from lzero.mcts import StochasticMuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ - configure_optimizers + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs +from lzero.policy.muzero import MuZeroPolicy from lzero.policy.utils import plot_topk_accuracy, visualize_avg_softmax, plot_argmax_distribution @POLICY_REGISTRY.register('stochastic_muzero') -class StochasticMuZeroPolicy(Policy): +class StochasticMuZeroPolicy(MuZeroPolicy): """ Overview: The policy class for Stochastic MuZero proposed in the paper https://openreview.net/pdf?id=X6D9bAHhBQ1. @@ -58,8 +58,9 @@ class StochasticMuZeroPolicy(Policy): ), # ****** common ****** # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) - # this variable is used in ``collector``. sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero). + gumbel_algo=False, # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. mcts_ctree=True, # (bool) Whether to use cuda for network. 
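For reference, the index arithmetic behind the ``_get_target_obs_index_in_step_k`` helper introduced in the ``lzero/policy/muzero.py`` diff above can be sketched in isolation as follows. This is a minimal illustration rather than the library code: the default argument values mirror the docstring example (``image_channel=3``, ``frame_stack_num=4``) and are assumptions for the sketch only.

def _get_target_obs_index_in_step_k_sketch(step_k, model_type='conv',
                                            image_channel=3, frame_stack_num=4,
                                            observation_shape=4):
    # 'conv': target observations are stacked along the channel dimension, so step k
    # starts at channel image_channel * k and spans frame_stack_num stacked frames.
    if model_type == 'conv':
        beg_index = image_channel * step_k
        end_index = image_channel * (step_k + frame_stack_num)
    # 'mlp': flat observations, each step spans observation_shape entries.
    elif model_type == 'mlp':
        beg_index = observation_shape * step_k
        end_index = observation_shape * (step_k + frame_stack_num)
    return beg_index, end_index

# Matches the docstring example: image_channel=3, frame_stack_num=4, step 0 -> (0, 12).
assert _get_target_obs_index_in_step_k_sketch(0) == (0, 12)

The unit test added in lzero/policy/tests/test_get_target_obs_index_in_step_k.py below exercises the same arithmetic through the policy object for both the 'conv' and 'mlp' configurations.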
@@ -181,7 +182,7 @@ class StochasticMuZeroPolicy(Policy): eps=dict( # (bool) Whether to use eps greedy exploration in collecting data. eps_greedy_exploration_in_collect=False, - # 'linear', 'exp' + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. type='linear', # (float) The start value of eps. start=1., @@ -202,7 +203,7 @@ def default_model(self) -> Tuple[str, List[str]]: - import_names (:obj:`List[str]`): The model class path list used in this algorithm. .. note:: The user can define and use customized network model but must obey the same interface definition indicated \ - by import_names path. For MuZero, ``lzero.model.muzero_model.MuZeroModel`` + by import_names path. For MuZero, ``lzero.model.muzero_model.MuZeroModel``. """ if self._cfg.model.model_type == "conv": return 'StochasticMuZeroModel', ['lzero.model.stochastic_muzero_model'] @@ -214,10 +215,10 @@ def default_model(self) -> Tuple[str, List[str]]: def _init_learn(self) -> None: """ Overview: - Learn mode init method. Called by ``self.__init__``. Ininitialize the learn model, optimizer and MCTS utils. + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ assert self._cfg.optim_type in ['SGD', 'Adam'], self._cfg.optim_type - # NOTE: in board_gmaes, for fixed lr 0.003, 'Adam' is better than 'SGD'. + # NOTE: in board_games, for fixed lr 0.003, 'Adam' is better than 'SGD'. if self._cfg.optim_type == 'SGD': self._optimizer = optim.SGD( self._model.parameters(), @@ -261,11 +262,11 @@ def _init_learn(self) -> None: def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. + The forward function for learning policy in learn mode, which is the core of the learning process. \ + The data is sampled from replay buffer. \ The loss is calculated by the loss function and the loss is backpropagated to update the model. Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. \ The first tensor is the current_batch, the second tensor is the target_batch. 
Returns: - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ @@ -286,15 +287,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in chance_one_hot_batch = torch.nn.functional.one_hot(chance_batch.long(), self._cfg.model.chance_space_size) obs_batch, obs_target_batch = prepare_obs(obs_batch_orig, self._cfg) - obs_list_for_chance_encoder = [] - obs_list_for_chance_encoder.append(obs_batch) - for i in range(self._cfg.num_unroll_steps): - beg_index = self._cfg.model.image_channel * i - end_index = self._cfg.model.image_channel * (i + self._cfg.model.frame_stack_num) - if self._cfg.model.model_type == 'conv': - obs_list_for_chance_encoder.append(obs_target_batch[:, beg_index:end_index, :, :]) - elif self._cfg.model.model_type == 'mlp': - obs_list_for_chance_encoder.append(obs_target_batch[:, beg_index*self._cfg.model.observation_shape:end_index*self._cfg.model.observation_shape]) + obs_list_for_chance_encoder = [obs_batch] + + for step_k in range(self._cfg.num_unroll_steps): + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + obs_list_for_chance_encoder.append(obs_target_batch[:, beg_index:end_index]) # do augmentations if self._cfg.use_augmentation: @@ -365,18 +362,16 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in afterstate_value_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) commitment_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) - gradient_scale = 1 / self._cfg.num_unroll_steps - # ============================================================== # the core recurrent_inference in MuZero policy. # ============================================================== - for step_i in range(self._cfg.num_unroll_steps): + for step_k in range(self._cfg.num_unroll_steps): # unroll with the afterstate dynamic function: predict 'afterstate', # given current ``state`` and ``action``. # 'afterstate reward' is not used, we kept it for the sake of uniformity between decision nodes and chance nodes. # And then predict afterstate_policy_logits and afterstate_value with the afterstate prediction function. 
network_output = self._learn_model.recurrent_inference( - latent_state, action_batch[:, step_i], afterstate=False + latent_state, action_batch[:, step_k], afterstate=False ) afterstate, afterstate_reward, afterstate_value, afterstate_policy_logits = mz_network_output_unpack(network_output) @@ -384,13 +379,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # encode the consecutive frames to predict chance # ============================================================== # concat consecutive frames to predict chance - former_frame = obs_list_for_chance_encoder[step_i] - latter_frame = obs_list_for_chance_encoder[step_i + 1] - concat_frame = torch.cat((former_frame, latter_frame), dim=1) + concat_frame = torch.cat((obs_list_for_chance_encoder[step_k], + obs_list_for_chance_encoder[step_k + 1]), dim=1) chance_encoding, chance_one_hot = self._learn_model.chance_encode(concat_frame) if self._cfg.use_ture_chance_label_in_chance_encoder: - true_chance_code = chance_batch[:, step_i] - true_chance_one_hot = chance_one_hot_batch[:, step_i] + true_chance_code = chance_batch[:, step_k] + true_chance_one_hot = chance_one_hot_batch[:, step_k] chance_code = true_chance_code else: chance_code = torch.argmax(chance_encoding, dim=1).long().unsqueeze(-1) @@ -412,16 +406,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== if self._cfg.ssl_loss_weight > 0: # obtain the oracle hidden states from representation function. - if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step_i - end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference( - obs_target_batch[:, beg_index:end_index, :, :] - ) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step_i - end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) @@ -429,7 +415,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # NOTE: no grad for the representation_state branch dynamic_proj = self._learn_model.project(latent_state, with_grad=True) observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in @@ -438,7 +424,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # calculate policy loss for the next ``num_unroll_steps`` unroll steps. # NOTE: the +=. 
# ============================================================== - policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) if self._cfg.use_ture_chance_label_in_chance_encoder: afterstate_policy_loss += cross_entropy_loss(afterstate_policy_logits, true_chance_one_hot.detach()) @@ -466,12 +452,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # calculate the topK accuracy of afterstate_policy_logits and plot the topK accuracy curve. plot_topk_accuracy(afterstate_policy_logits, true_chance_one_hot, topK_values) - # TODO(pu): whether to detach the chance_encoding in the commitment loss. commitment_loss += torch.nn.MSELoss()(chance_encoding, chance_one_hot.float()) - afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_i]) - value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) - reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) + afterstate_value_loss += cross_entropy_loss(afterstate_value, target_value_categorical[:, step_k]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) if self._cfg.monitor_extra_statistics: original_rewards = self.inverse_scalar_transform_handle(reward) @@ -544,7 +529,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in ) return { - 'collect_mcts_temperature': self.collect_mcts_temperature, + 'collect_mcts_temperature': self._collect_mcts_temperature, 'cur_lr': self._optimizer.param_groups[0]['lr'], 'weighted_total_loss': loss_info[0], 'total_loss': loss_info[1], @@ -561,6 +546,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== 'value_priority_orig': value_priority, 'value_priority': td_data[0].flatten().mean().item(), + 'target_reward': td_data[1].flatten().mean().item(), 'target_value': td_data[2].flatten().mean().item(), 'transformed_target_reward': td_data[3].flatten().mean().item(), @@ -573,14 +559,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in def _init_collect(self) -> None: """ Overview: - Collect mode init method. Called by ``self.__init__``. Ininitialize the collect model and MCTS utils. + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ self._collect_model = self._model if self._cfg.mcts_ctree: self._mcts_collect = MCTSCtree(self._cfg) else: self._mcts_collect = MCTSPtree(self._cfg) - self.collect_mcts_temperature = 1 + self._collect_mcts_temperature = 1 def _forward_collect( self, @@ -593,7 +579,7 @@ def _forward_collect( ) -> Dict: """ Overview: - The forward function for collecting data in collect mode. Use model to execute MCTS search. + The forward function for collecting data in collect mode. Use model to execute MCTS search. \ Choosing the action through sampling during the collect mode. Arguments: - data (:obj:`torch.Tensor`): The input data, i.e. the observation. @@ -603,9 +589,9 @@ def _forward_collect( - ready_env_id (:obj:`list`): The id of the env that is ready to collect. 
Shape: - data (:obj:`torch.Tensor`): - - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + - For Atari, its shape is :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ S is the number of stacked frames, H is the height of the image, W is the width of the image. - - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - For lunarlander, its shape is :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - temperature: :math:`(1, )`. - to_play: :math:`(N, 1)`, where N is the number of collect_env. @@ -615,7 +601,7 @@ def _forward_collect( ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. """ self._collect_model.eval() - self.collect_mcts_temperature = temperature + self._collect_mcts_temperature = temperature active_collect_env_num = data.shape[0] with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} @@ -644,8 +630,8 @@ def _forward_collect( roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_collect_env_num)] @@ -659,7 +645,7 @@ def _forward_collect( # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - distributions, temperature=self.collect_mcts_temperature, deterministic=False + distributions, temperature=self._collect_mcts_temperature, deterministic=False ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the # entire action set. @@ -678,7 +664,7 @@ def _forward_collect( def _init_eval(self) -> None: """ Overview: - Evaluate mode init method. Called by ``self.__init__``. Ininitialize the eval model and MCTS utils. + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. """ self._eval_model = self._model if self._cfg.mcts_ctree: @@ -689,7 +675,7 @@ def _init_eval(self) -> None: def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: """ Overview: - The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. \ Choosing the action with the highest value (argmax) rather than sampling during the eval mode. Arguments: - data (:obj:`torch.Tensor`): The input data, i.e. the observation. 
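To make the action-selection step in ``_forward_collect`` / ``_forward_eval`` concrete, here is a minimal sketch of how a temperature-scaled choice over root visit counts is mapped back to the full action space via ``action_mask``. ``select_action_sketch`` is an illustrative stand-in, not the actual ``lzero.policy.select_action`` implementation, and the mask and visit-count values are hypothetical.

import numpy as np

def select_action_sketch(visit_counts, temperature=1.0, deterministic=False):
    # Assumed behaviour: turn root visit counts into a distribution pi(a) ~ N(a)^(1/T),
    # sample from it while collecting, and take the argmax when deterministic (eval).
    counts = np.asarray(visit_counts, dtype=np.float64)
    probs = counts ** (1.0 / temperature)
    probs = probs / probs.sum()
    entropy = float(-np.sum(probs * np.log(probs + 1e-9)))
    if deterministic:
        return int(np.argmax(counts)), entropy
    return int(np.random.choice(len(counts), p=probs)), entropy

# Only legal actions possess visit counts, so the sampled index refers to the legal
# action set; map it back to the entire action set exactly as done in the policies above.
action_mask = np.array([1.0, 0.0, 1.0, 1.0])   # hypothetical mask for one env
distributions = [30, 5, 15]                    # hypothetical visit counts over the 3 legal actions
action_index_in_legal_action_set, _ = select_action_sketch(distributions, temperature=1.0)
action = np.where(action_mask == 1.0)[0][action_index_in_legal_action_set]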
@@ -731,8 +717,8 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 roots.prepare_no_noise(reward_roots, policy_logits, to_play) self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} data_id = [i for i in range(active_eval_env_num)] @@ -768,7 +754,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 def _monitor_vars_learn(self) -> List[str]: """ Overview: - Register the variables to be monitored in learn mode. The registered variables will be logged in + Register the variables to be monitored in learn mode. The registered variables will be logged in \ tensorboard according to the return value ``_forward_learn``. """ return [ diff --git a/lzero/policy/tests/config/__init__.py b/lzero/policy/tests/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lzero/policy/tests/config/atari_muzero_config_for_test.py b/lzero/policy/tests/config/atari_muzero_config_for_test.py new file mode 100644 index 000000000..05c64d419 --- /dev/null +++ b/lzero/policy/tests/config/atari_muzero_config_for_test.py @@ -0,0 +1,98 @@ +from easydict import EasyDict + +env_name = 'PongNoFrameskip-v4' + +if env_name == 'PongNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'QbertNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'MsPacmanNoFrameskip-v4': + action_space_size = 9 +elif env_name == 'SpaceInvadersNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'BreakoutNoFrameskip-v4': + action_space_size = 4 + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. 
+eps_greedy_exploration_in_collect = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_muzero_config = dict( + exp_name= + f'data_mz_ctree/{env_name[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + stop_value=int(1e6), + env_name=env_name, + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + # need to dynamically adjust the number of decay steps + # according to the characteristics of the environment and the algorithm + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +atari_muzero_config = EasyDict(atari_muzero_config) +main_config = atari_muzero_config + +atari_muzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), +) +atari_muzero_create_config = EasyDict(atari_muzero_create_config) +create_config = atari_muzero_create_config \ No newline at end of file diff --git a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py new file mode 100644 index 000000000..b7584d19f --- /dev/null +++ b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py @@ -0,0 +1,74 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +reanalyze_ratio = 0 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cartpole_muzero_config = dict( + exp_name=f'data_mz_ctree/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + 
manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +cartpole_muzero_config = EasyDict(cartpole_muzero_config) +main_config = cartpole_muzero_config + +cartpole_muzero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), +) +cartpole_muzero_create_config = EasyDict(cartpole_muzero_create_config) +create_config = cartpole_muzero_create_config \ No newline at end of file diff --git a/lzero/policy/tests/test_get_target_obs_index_in_step_k.py b/lzero/policy/tests/test_get_target_obs_index_in_step_k.py new file mode 100644 index 000000000..3558472f2 --- /dev/null +++ b/lzero/policy/tests/test_get_target_obs_index_in_step_k.py @@ -0,0 +1,73 @@ +import pytest +import torch +from ding.config import compile_config +from ding.policy import create_policy + +args = ['conv', 'mlp'] + +@pytest.mark.unittest +@pytest.mark.parametrize('test_mode_type', args) +def test_get_target_obs_index_in_step_k(test_mode_type): + """ + Overview: + Unit test for the _get_target_obs_index_in_step_k method. + We will test for two types of model_type: 'conv' and 'mlp'. + Arguments: + - test_mode_type (:obj:`str`): The type of model to test, which can be 'conv' or 'mlp'. 
+ """ + # Import the relevant model and configuration + from lzero.model.muzero_model import MuZeroModel as Model + if test_mode_type == 'conv': + from lzero.policy.tests.config.atari_muzero_config_for_test import atari_muzero_config as cfg + from lzero.policy.tests.config.atari_muzero_config_for_test import atari_muzero_create_config as create_cfg + + elif test_mode_type == 'mlp': + from lzero.policy.tests.config.cartpole_muzero_config_for_test import cartpole_muzero_config as cfg + from lzero.policy.tests.config.cartpole_muzero_config_for_test import \ + cartpole_muzero_create_config as create_cfg + + # Create model + model = Model(**cfg.policy.model) + if cfg.policy.cuda and torch.cuda.is_available(): + cfg.policy.device = 'cuda' + else: + cfg.policy.device = 'cpu' + + # Compile configuration + cfg = compile_config(cfg, seed=0, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Move model to the specified device and set it to evaluation mode + model.to(cfg.policy.device) + model.eval() + + # Create policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if test_mode_type == 'conv': + + # Test case 1: model_type = 'conv' + policy._cfg.model.model_type = 'conv' + # Assume the current step is 2 + step = 2 + # For 'conv' type, the expected start and end index should be (image_channel * step, image_channel * (step + frame_stack_num)) + expected_beg_index, expected_end_index = 2, 6 + # Get the actual start and end index + beg_index, end_index = policy._get_target_obs_index_in_step_k(step) + + # Assert that the actual start and end index match the expected ones + assert beg_index == expected_beg_index + assert end_index == expected_end_index + + elif test_mode_type == 'mlp': + # Test case 2: model_type = 'mlp' + policy._cfg.model.model_type = 'mlp' + # Assume the current step is 2 + step = 2 + # For 'mlp' type, the expected start and end index should be (observation_shape * step, observation_shape * (step + frame_stack_num)) + expected_beg_index, expected_end_index = 8, 12 + # Get the actual start and end index + beg_index, end_index = policy._get_target_obs_index_in_step_k(step) + + # Assert that the actual start and end index match the expected ones + assert beg_index == expected_beg_index + assert end_index == expected_end_index \ No newline at end of file diff --git a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py index 4b21a830a..65dbf6418 100644 --- a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py @@ -39,7 +39,6 @@ norm_type='BN', ), mcts_ctree=True, - gumbel_algo=False, cuda=True, env_type='not_board_games', game_segment_length=50, @@ -53,7 +52,7 @@ reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, eval_freq=int(2e2), - replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ replay_buffer_size=int(1e6), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ), diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index d298d3031..5b8d90d8b 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -28,10 +28,6 @@ env_name=env_name, obs_shape=(16, 4, 4), obs_type='dict_encoded_board', - raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' - reward_normalize=False, - reward_scale=100, - max_tile=int(2**16), # 2**11=2048, 2**16=65536 num_of_possible_chance_tile=num_of_possible_chance_tile, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -45,13 +41,12 @@ chance_space_size=chance_space_size, image_channel=16, # NOTE: whether to use the self_supervised_learning_loss. default is False - self_supervised_learning_loss=True, # default is False + self_supervised_learning_loss=True, discrete_action_encoding_type='one_hot', norm_type='BN', ), use_ture_chance_label_in_chance_encoder=use_ture_chance_label_in_chance_encoder, mcts_ctree=True, - gumbel_algo=False, cuda=True, game_segment_length=200, update_per_collect=update_per_collect, @@ -68,7 +63,7 @@ ssl_loss_weight=2, # default is 0 n_episode=n_episode, eval_freq=int(2e3), - replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + replay_buffer_size=int(1e6), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, ), diff --git a/zoo/game_2048/entry/2048_bot_eval.py b/zoo/game_2048/entry/2048_bot_eval.py index 1ca2a042c..42b263234 100644 --- a/zoo/game_2048/entry/2048_bot_eval.py +++ b/zoo/game_2048/entry/2048_bot_eval.py @@ -9,8 +9,8 @@ config = EasyDict(dict( env_name="game_2048", save_replay=False, - replay_format='mp4', - replay_name_suffix='test', + replay_format='gif', + replay_name_suffix='bot', replay_path=None, render_real_time=False, act_scale=True, @@ -42,13 +42,8 @@ grid = obs.astype(np.int64) # action = game_2048_env.human_to_action() # which obtain about 10000 score # action = game_2048_env.random_action() # which obtain about 1000 score - action = expectimax_search(grid) # which obtain about 58536 score - try: - obs, reward, done, info = game_2048_env.step(action) - except Exception as e: - print(f'Exception: {e}') - print('total_step_number: {}'.format(step)) - break + action = expectimax_search(grid) # which obtain about 300000~70000 score + obs, reward, done, info = game_2048_env.step(action) step += 1 print(f"step: {step}, action: {action}, reward: {reward}, raw_reward: {info['raw_reward']}") game_2048_env.render(mode='human') diff --git a/zoo/game_2048/entry/stochastic_muzero_2048_eval.py b/zoo/game_2048/entry/2048_eval.py similarity index 89% rename from zoo/game_2048/entry/stochastic_muzero_2048_eval.py rename to zoo/game_2048/entry/2048_eval.py index df8b0956f..2a9bb6f58 100644 --- a/zoo/game_2048/entry/stochastic_muzero_2048_eval.py +++ b/zoo/game_2048/entry/2048_eval.py @@ -2,6 +2,7 @@ import numpy as np from lzero.entry import eval_muzero +from zoo.game_2048.config.muzero_2048_config import main_config, create_config from zoo.game_2048.config.stochastic_muzero_2048_config import main_config, create_config if __name__ == "__main__": @@ -22,8 +23,10 @@ main_config.env.evaluator_env_num = 1 # Visualization requires the 'env_num' to be set as 1 main_config.env.n_evaluator_episode = total_test_episodes main_config.env.save_replay = True # Whether 
to save the replay, if save the video render_mode_human must to be True - main_config.env.replay_format = 'mp4' - main_config.env.replay_name_suffix = 'ns100_s1' + main_config.env.replay_format = 'gif' + main_config.env.replay_name_suffix = 'muzero_ns100_s0' + # main_config.env.replay_name_suffix = 'stochastic_muzero_ns100_s0' + main_config.env.replay_path = None main_config.env.max_episode_steps = int(1e9) # Adjust according to different environments diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 95c45a01b..6873a72e9 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -85,27 +85,49 @@ class Game2048Env(gym.Env): # The default_config for game 2048 env. config = dict( + # (str) The name of the environment registered in the environment registry. env_name="game_2048", + # (bool) Whether to save the replay of the game. save_replay=False, + # (str) The format in which to save the replay. 'gif' is a popular choice. replay_format='gif', + # (str) A suffix for the replay file name to distinguish it from other files. replay_name_suffix='eval', + # (str or None) The directory in which to save the replay file. If None, the file is saved in the current directory. replay_path=None, + # (bool) Whether to render the game in real time. Useful for debugging, but can slow down training. render_real_time=False, + # (bool) Whether to scale the actions. If True, actions are divided by the action space size. act_scale=True, + # (bool) Whether to use the 'channel last' format for the observation space. If False, 'channel first' format is used. channel_last=True, - obs_type='dict_encoded_board', # options=['raw_board', 'raw_encoded_board', 'dict_encoded_board'] + # (str) The type of observation to use. Options are 'raw_board', 'raw_encoded_board', and 'dict_encoded_board'. + obs_type='dict_encoded_board', + # (bool) Whether to normalize rewards. If True, rewards are divided by the maximum possible reward. reward_normalize=False, + # (float) The factor to scale rewards by when reward normalization is used. reward_norm_scale=100, - reward_type='raw', # options=['raw', 'merged_tiles_plus_log_max_tile_num'] - max_tile=int(2 ** 16), # 2**11=2048, 2**16=65536 + # (str) The type of reward to use. 'raw' means the raw game score. 'merged_tiles_plus_log_max_tile_num' is an alternative. + reward_type='raw', + # (int) The maximum tile number in the game. A game is won when this tile appears. 2**11=2048, 2**16=65536 + max_tile=int(2 ** 16), + # (int) The number of steps to delay rewards by. If > 0, the agent only receives a reward every this many steps. delay_reward_step=0, + # (float) The probability that a random agent is used instead of the learning agent. prob_random_agent=0., + # (int) The maximum number of steps in an episode. max_episode_steps=int(1e6), + # (bool) Whether to collect data during the game. is_collect=True, + # (bool) Whether to ignore legal actions. If True, the agent can take any action, even if it's not legal. ignore_legal_actions=True, + # (bool) Whether to flatten the observation space. If True, the observation space is a 1D array instead of a 2D grid. need_flatten=False, + # (int) The number of possible tiles that can appear after each move. num_of_possible_chance_tile=2, + # (numpy array) The possible tiles that can appear after each move. possible_tiles=np.array([2, 4]), + # (numpy array) The probabilities corresponding to each possible tile. 
tile_probabilities=np.array([0.9, 0.1]), ) @@ -600,7 +622,7 @@ def human_to_action(self): try: action = int( input( - f"Enter the action (0, 1, 2, or 3, ) to play: " + f"Enter the action (0(Up), 1(Right), 2(Down), or 3(Left)) to play: " ) ) if action in self.legal_actions: @@ -673,11 +695,9 @@ def draw_tile(self, draw, x, y, o, fnt): text_x_size, text_y_size = bbox[2] - bbox[0], bbox[3] - bbox[1] draw.text((x * grid_size + (grid_size - text_x_size) // 2, y * grid_size + (grid_size - text_y_size) // 2), str(o), font=fnt, fill=white) - # assert text_x_size < grid_size - # assert text_y_size < grid_size def save_render_output(self, replay_name_suffix: str = '', replay_path=None, format='gif'): - # At the end of the episode, save the frames + # At the end of the episode, save the frames to a gif or mp4 file if replay_path is None: filename = f'game_2048_{replay_name_suffix}.{format}' else: From 7ea632f29c0a7edb375728ddb6169ceaa544637c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Tue, 12 Sep 2023 19:01:16 +0800 Subject: [PATCH 28/28] polish(pu): use _get_target_obs_index_in_step_k in all policy, rename step_i to step_k --- lzero/policy/efficientzero.py | 49 ++++---- lzero/policy/gumbel_muzero.py | 53 +++++---- lzero/policy/muzero.py | 50 ++++---- lzero/policy/sampled_efficientzero.py | 111 +++++++++--------- lzero/policy/stochastic_muzero.py | 2 +- .../test_get_target_obs_index_in_step_k.py | 2 +- .../config/stochastic_muzero_2048_config.py | 2 +- 7 files changed, 135 insertions(+), 134 deletions(-) diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 4a4da03c5..b18a9e297 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -5,7 +5,6 @@ import torch import torch.optim as optim from ding.model import model_wrap -from ding.policy.base_policy import Policy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from torch.distributions import Categorical @@ -15,15 +14,17 @@ from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ + DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, \ + prepare_obs, \ configure_optimizers +from lzero.policy.muzero import MuZeroPolicy @POLICY_REGISTRY.register('efficientzero') -class EfficientZeroPolicy(Policy): +class EfficientZeroPolicy(MuZeroPolicy): """ Overview: - The policy class for EfficientZero. + The policy class for EfficientZero proposed in the paper https://arxiv.org/abs/2111.00210. """ # The default_config for EfficientZero policy. @@ -179,7 +180,7 @@ class EfficientZeroPolicy(Policy): eps=dict( # (bool) Whether to use eps greedy exploration in collecting data. eps_greedy_exploration_in_collect=False, - # 'linear', 'exp' + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. type='linear', # (float) The start value of eps. start=1., @@ -269,11 +270,11 @@ def _init_learn(self) -> None: def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: """ Overview: - The forward function for learning policy in learn mode, which is the core of the learning process. - The data is sampled from replay buffer. 
+ The forward function for learning policy in learn mode, which is the core of the learning process. \ + The data is sampled from replay buffer. \ The loss is calculated by the loss function and the loss is backpropagated to update the model. Arguments: - - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. + - data (:obj:`Tuple[torch.Tensor]`): The data sampled from replay buffer, which is a tuple of tensors. \ The first tensor is the current_batch, the second tensor is the target_batch. Returns: - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \ @@ -373,12 +374,12 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== # the core recurrent_inference in EfficientZero policy. # ============================================================== - for step_i in range(self._cfg.num_unroll_steps): + for step_k in range(self._cfg.num_unroll_steps): # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state``, # `` value_prefix`` given current ``latent_state`` ``reward_hidden_state`` and ``action``. # And then predict policy_logits and value with the prediction function. network_output = self._learn_model.recurrent_inference( - latent_state, reward_hidden_state, action_batch[:, step_i] + latent_state, reward_hidden_state, action_batch[:, step_k] ) latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack( network_output @@ -392,15 +393,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. # ============================================================== if self._cfg.ssl_loss_weight > 0: - # obtain the oracle hidden states from representation function. - if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step_i - end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index, :, :]) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step_i - end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + # obtain the oracle latent states from representation function. + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) @@ -408,7 +403,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # NOTE: no grad for the representation_state branch. dynamic_proj = self._learn_model.project(latent_state, with_grad=True) observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss @@ -418,16 +413,16 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # calculate policy loss for the next ``num_unroll_steps`` unroll steps. 
# NOTE: the +=. # ============================================================== - policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_i + 1]) + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) - # Here we take the hypothetical step k = step_i + 1 + # Here we take the hypothetical step k = step_k + 1 prob = torch.softmax(policy_logits, dim=-1) dist = Categorical(prob) policy_entropy += dist.entropy().mean() - target_normalized_visit_count = target_policy[:, step_i + 1] + target_normalized_visit_count = target_policy[:, step_k + 1] # ******* NOTE: target_policy_entropy is only for debug. ****** - non_masked_indices = torch.nonzero(mask_batch[:, step_i + 1]).squeeze(-1) + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) # Check if there are any unmasked rows if len(non_masked_indices) > 0: target_normalized_visit_count_masked = torch.index_select( @@ -439,11 +434,11 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # Set target_policy_entropy to 0 if all rows are masked target_policy_entropy += 0 - value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) - value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_i]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_k]) # reset hidden states every ``lstm_horizon_len`` unroll steps. - if (step_i + 1) % self._cfg.lstm_horizon_len == 0: + if (step_k + 1) % self._cfg.lstm_horizon_len == 0: reward_hidden_state = ( torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index 40609c23b..b4f88a387 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -5,7 +5,6 @@ import torch import torch.optim as optim from ding.model import model_wrap -from ding.policy.base_policy import Policy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from torch.nn import L1Loss, KLDivLoss @@ -14,15 +13,17 @@ from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs, \ configure_optimizers +from lzero.policy.muzero import MuZeroPolicy @POLICY_REGISTRY.register('gumbel_muzero') -class GumeblMuZeroPolicy(Policy): +class GumeblMuZeroPolicy(MuZeroPolicy): """ Overview: - The policy class for Gumbel Muzero. Paper link: https://openreview.net/forum?id=bERaNdoegnO + The policy class for Gumbel MuZero proposed in the paper https://openreview.net/forum?id=bERaNdoegnO. """ # The default_config for Gumbel MuZero policy. @@ -169,6 +170,24 @@ class GumeblMuZeroPolicy(Policy): root_dirichlet_alpha=0.3, # (float) The noise weight at the root node of the search tree. root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. 
+ random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), ) def default_model(self) -> Tuple[str, List[str]]: @@ -338,11 +357,11 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== # the core recurrent_inference in Gumbel MuZero policy. # ============================================================== - for step_i in range(self._cfg.num_unroll_steps): + for step_k in range(self._cfg.num_unroll_steps): # unroll with the dynamics function: predict the next ``hidden_state``, ``reward``, # given current ``hidden_state`` and ``action``. # And then predict policy_logits and value with the prediction function. - network_output = self._learn_model.recurrent_inference(hidden_state, action_batch[:, step_i]) + network_output = self._learn_model.recurrent_inference(hidden_state, action_batch[:, step_k]) hidden_state, reward, value, policy_logits = mz_network_output_unpack(network_output) # transform the scaled value or its categorical representation to its original value, @@ -354,17 +373,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. # ============================================================== if self._cfg.ssl_loss_weight > 0: - # obtain the oracle hidden states from representation function. - if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step_i - end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference( - obs_target_batch[:, beg_index:end_index, :, :] - ) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step_i - end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + # obtain the oracle latent states from representation function. + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) hidden_state = to_tensor(hidden_state) representation_state = to_tensor(network_output.latent_state) @@ -372,7 +383,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # NOTE: no grad for the representation_state branch dynamic_proj = self._learn_model.project(hidden_state, with_grad=True) observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss # NOTE: the target policy, target_value_categorical, target_reward_categorical is calculated in @@ -381,9 +392,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # calculate policy loss for the next ``num_unroll_steps`` unroll steps. # NOTE: the +=. 
# ============================================================== - policy_loss += self.kl_loss(torch.log(torch.softmax(policy_logits, dim=1)),torch.from_numpy(improved_policy_batch[:, step_i + 1]).to(self._cfg.device).detach().float()).mean(dim=-1) * mask_batch[:,step_i+1] - value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) - reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_i]) + policy_loss += self.kl_loss(torch.log(torch.softmax(policy_logits, dim=1)),torch.from_numpy(improved_policy_batch[:, step_k + 1]).to(self._cfg.device).detach().float()).mean(dim=-1) * mask_batch[:,step_k+1] + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) entropy_loss += -torch.sum(torch.softmax(policy_logits, dim=1) * torch.log(torch.softmax(policy_logits, dim=1)), dim=-1) # Follow MuZero, set half gradient diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 97ffa3ccd..a72f6e748 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -179,7 +179,7 @@ class MuZeroPolicy(Policy): eps=dict( # (bool) Whether to use eps greedy exploration in collecting data. eps_greedy_exploration_in_collect=False, - # 'linear', 'exp' + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. type='linear', # (float) The start value of eps. start=1., @@ -457,30 +457,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() } - def _get_target_obs_index_in_step_k(self, step): - """ - Overview: - Get the begin index and end index of the target obs in step k. - Arguments: - - step (:obj:`int`): The current step k. - Returns: - - beg_index (:obj:`int`): The begin index of the target obs in step k. - - end_index (:obj:`int`): The end index of the target obs in step k. - Examples: - >>> self._cfg.model.model_type = 'conv' - >>> self._cfg.model.image_channel = 3 - >>> self._cfg.model.frame_stack_num = 4 - >>> self._get_target_obs_index_in_step_k(0) - >>> (0, 12) - """ - if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step - end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step - end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) - return beg_index, end_index - def _init_collect(self) -> None: """ Overview: @@ -606,6 +582,30 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) + def _get_target_obs_index_in_step_k(self, step): + """ + Overview: + Get the begin index and end index of the target obs in step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The begin index of the target obs in step k. + - end_index (:obj:`int`): The end index of the target obs in step k. 
+ Examples: + >>> self._cfg.model.model_type = 'conv' + >>> self._cfg.model.image_channel = 3 + >>> self._cfg.model.frame_stack_num = 4 + >>> self._get_target_obs_index_in_step_k(0) + >>> (0, 12) + """ + if self._cfg.model.model_type == 'conv': + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type == 'mlp': + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: """ Overview: diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 03e10468a..fc52387e7 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -5,7 +5,6 @@ import torch import torch.optim as optim from ding.model import model_wrap -from ding.policy.base_policy import Policy from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from ditk import logging @@ -16,18 +15,20 @@ from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + DiscreteSupport, to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, \ + prepare_obs, \ configure_optimizers +from lzero.policy.muzero import MuZeroPolicy @POLICY_REGISTRY.register('sampled_efficientzero') -class SampledEfficientZeroPolicy(Policy): +class SampledEfficientZeroPolicy(MuZeroPolicy): """ Overview: - The policy class for Sampled EfficientZero. + The policy class for Sampled EfficientZero proposed in the paper https://arxiv.org/abs/2104.06303. """ - # The default_config for Sampled fEficientZero policy. + # The default_config for Sampled EfficientZero policy. config = dict( model=dict( # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. @@ -64,7 +65,7 @@ class SampledEfficientZeroPolicy(Policy): # (bool) whether to use res connection in dynamics. res_connection_in_dynamics=True, # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. - norm_type='BN', + norm_type='BN', ), # ****** common ****** # (bool) Whether to use multi-gpu training. @@ -194,7 +195,7 @@ class SampledEfficientZeroPolicy(Policy): eps=dict( # (bool) Whether to use eps greedy exploration in collecting data. eps_greedy_exploration_in_collect=False, - # 'linear', 'exp' + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. type='linear', # (float) The start value of eps. start=1., @@ -224,6 +225,7 @@ def default_model(self) -> Tuple[str, List[str]]: return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] else: raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) + def _init_learn(self) -> None: """ Overview: @@ -403,12 +405,12 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== # the core recurrent_inference in SampledEfficientZero policy. 
# ============================================================== - for step_i in range(self._cfg.num_unroll_steps): + for step_k in range(self._cfg.num_unroll_steps): # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state``, # `` value_prefix`` given current ``latent_state`` ``reward_hidden_state`` and ``action``. # And then predict policy_logits and value with the prediction function. network_output = self._learn_model.recurrent_inference( - latent_state, reward_hidden_state, action_batch[:, step_i] + latent_state, reward_hidden_state, action_batch[:, step_k] ) latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack( network_output @@ -423,17 +425,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. # ============================================================== if self._cfg.ssl_loss_weight > 0: - # obtain the oracle hidden states from representation function. - if self._cfg.model.model_type == 'conv': - beg_index = self._cfg.model.image_channel * step_i - end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference( - obs_target_batch[:, beg_index:end_index, :, :] - ) - elif self._cfg.model.model_type == 'mlp': - beg_index = self._cfg.model.observation_shape * step_i - end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) - network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + # obtain the oracle latent states from representation function. + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) latent_state = to_tensor(latent_state) representation_state = to_tensor(network_output.latent_state) @@ -441,7 +435,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # NOTE: no grad for the representation_state branch. dynamic_proj = self._learn_model.project(latent_state, with_grad=True) observation_proj = self._learn_model.project(representation_state, with_grad=False) - temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_i] + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] consistency_loss += temp_loss @@ -460,7 +454,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: target_policy, mask_batch, child_sampled_actions_batch, - unroll_step=step_i + 1 + unroll_step=step_k + 1 ) else: """discrete action space""" @@ -470,14 +464,14 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: target_policy, mask_batch, child_sampled_actions_batch, - unroll_step=step_i + 1 + unroll_step=step_k + 1 ) - value_loss += cross_entropy_loss(value, target_value_categorical[:, step_i + 1]) - value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_i]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_k]) # reset hidden states every ``lstm_horizon_len`` unroll steps. 
- if (step_i + 1) % self._cfg.lstm_horizon_len == 0: + if (step_k + 1) % self._cfg.lstm_horizon_len == 0: reward_hidden_state = ( torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device), torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device) @@ -499,9 +493,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== # weighted loss with masks (some invalid states which are out of trajectory.) loss = ( - self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + - self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss + - self._cfg.policy_entropy_loss_weight * policy_entropy_loss + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss + + self._cfg.policy_entropy_loss_weight * policy_entropy_loss ) weighted_total_loss = (weights * loss).mean() @@ -528,28 +522,28 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) return_data = { - 'cur_lr': self._optimizer.param_groups[0]['lr'], - 'collect_mcts_temperature': self._collect_mcts_temperature, - 'weighted_total_loss': weighted_total_loss.item(), - 'total_loss': loss.mean().item(), - 'policy_loss': policy_loss.mean().item(), - 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), - 'value_prefix_loss': value_prefix_loss.mean().item(), - 'value_loss': value_loss.mean().item(), - 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'weighted_total_loss': weighted_total_loss.item(), + 'total_loss': loss.mean().item(), + 'policy_loss': policy_loss.mean().item(), + 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'value_prefix_loss': value_prefix_loss.mean().item(), + 'value_loss': value_loss.mean().item(), + 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, - # ============================================================== - # priority related - # ============================================================== - 'value_priority': value_priority.flatten().mean().item(), - 'value_priority_orig': value_priority, - 'target_value_prefix': target_value_prefix.detach().cpu().numpy().mean().item(), - 'target_value': target_value.detach().cpu().numpy().mean().item(), - 'transformed_target_value_prefix': transformed_target_value_prefix.detach().cpu().numpy().mean().item(), - 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), - 'predicted_value_prefixs': predicted_value_prefixs.detach().cpu().numpy().mean().item(), - 'predicted_values': predicted_values.detach().cpu().numpy().mean().item() + # ============================================================== + # priority related + # ============================================================== + 'value_priority': value_priority.flatten().mean().item(), + 'value_priority_orig': value_priority, + 
'target_value_prefix': target_value_prefix.detach().cpu().numpy().mean().item(), + 'target_value': target_value.detach().cpu().numpy().mean().item(), + 'transformed_target_value_prefix': transformed_target_value_prefix.detach().cpu().numpy().mean().item(), + 'transformed_target_value': transformed_target_value.detach().cpu().numpy().mean().item(), + 'predicted_value_prefixs': predicted_value_prefixs.detach().cpu().numpy().mean().item(), + 'predicted_values': predicted_values.detach().cpu().numpy().mean().item() } if self._cfg.model.continuous_action_space: @@ -580,7 +574,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: 'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(), 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() }) - + return return_data def _calculate_policy_loss_cont( @@ -680,9 +674,9 @@ def _calculate_policy_loss_cont( if self._cfg.policy_loss_type == 'KL': # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] + torch.exp(target_log_prob_sampled_actions.detach()) * + (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) + ).sum(-1) * mask_batch[:, unroll_step] elif self._cfg.policy_loss_type == 'cross_entropy': # cross_entropy loss: - sum(p * log (q) ) policy_loss += -torch.sum( @@ -767,9 +761,9 @@ def _calculate_policy_loss_disc( if self._cfg.policy_loss_type == 'KL': # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] + torch.exp(target_log_prob_sampled_actions.detach()) * + (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) + ).sum(-1) * mask_batch[:, unroll_step] elif self._cfg.policy_loss_type == 'cross_entropy': # cross_entropy loss: - sum(p * log (q) ) policy_loss += -torch.sum( @@ -791,7 +785,8 @@ def _init_collect(self) -> None: self._collect_mcts_temperature = 1 def _forward_collect( - self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, epsilon: float = 0.25, ready_env_id=None + self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, + epsilon: float = 0.25, ready_env_id=None ): """ Overview: diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index d0cf2ed7d..78f66213f 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -153,7 +153,7 @@ class StochasticMuZeroPolicy(MuZeroPolicy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS. + # (bool) Whether to use the true chance in MCTS. If False, use the predicted chance. 
use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** diff --git a/lzero/policy/tests/test_get_target_obs_index_in_step_k.py b/lzero/policy/tests/test_get_target_obs_index_in_step_k.py index 3558472f2..2e993da43 100644 --- a/lzero/policy/tests/test_get_target_obs_index_in_step_k.py +++ b/lzero/policy/tests/test_get_target_obs_index_in_step_k.py @@ -5,6 +5,7 @@ args = ['conv', 'mlp'] + @pytest.mark.unittest @pytest.mark.parametrize('test_mode_type', args) def test_get_target_obs_index_in_step_k(test_mode_type): @@ -44,7 +45,6 @@ def test_get_target_obs_index_in_step_k(test_mode_type): policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) if test_mode_type == 'conv': - # Test case 1: model_type = 'conv' policy._cfg.model.model_type = 'conv' # Assume the current step is 2 diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 5b8d90d8b..367124478 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -6,7 +6,7 @@ # ============================================================== env_name = 'game_2048' action_space_size = 4 -use_ture_chance_label_in_chance_encoder = True # option: {True, False} +use_ture_chance_label_in_chance_encoder = True collector_env_num = 8 n_episode = 8 evaluator_env_num = 3
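The hunks above fold the duplicated per-step target-observation slicing into a single MuZeroPolicy._get_target_obs_index_in_step_k helper, which the EfficientZero, Gumbel MuZero and Sampled EfficientZero SSL consistency-loss loops now call instead of re-deriving the indices inline. The following is only a minimal standalone sketch of that index arithmetic, not part of the patch: the 'conv' defaults (image_channel=3, frame_stack_num=4) mirror the doctest in the moved docstring, while the 'mlp' default observation_shape=8 is a placeholder chosen purely for illustration.

# Sketch of the index arithmetic consolidated into
# MuZeroPolicy._get_target_obs_index_in_step_k (see the hunks above).
def get_target_obs_index_in_step_k(model_type: str, step: int,
                                    image_channel: int = 3,
                                    frame_stack_num: int = 4,
                                    observation_shape: int = 8) -> tuple:
    """Return the (begin, end) slice of the stacked target obs for unroll step k."""
    if model_type == 'conv':
        # Image observations are stacked along the channel dimension.
        beg_index = image_channel * step
        end_index = image_channel * (step + frame_stack_num)
    elif model_type == 'mlp':
        # Vector observations are concatenated along the feature dimension.
        beg_index = observation_shape * step
        end_index = observation_shape * (step + frame_stack_num)
    else:
        raise ValueError(f"unsupported model_type: {model_type}")
    return beg_index, end_index

assert get_target_obs_index_in_step_k('conv', 0) == (0, 12)  # matches the doctest example
assert get_target_obs_index_in_step_k('conv', 2) == (6, 18)  # step 2, as assumed in the unit test

In the patch itself the returned pair is used to slice the stacked target observations before the representation call, i.e. obs_target_batch[:, beg_index:end_index] is fed to initial_inference at unroll step k.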