feature(pu/zt): add 2048 env and Stochastic MuZero #64

Merged
32 commits, merged Sep 12, 2023

Commits
27d9b01
add stochastic mz ptree
timothijoe Jun 10, 2023
3799b46
add stochastic mz ctree
timothijoe Jun 10, 2023
14e3822
add box2d, classic conrol, and 2048 config
Jun 13, 2023
06c0558
made corrections to the comments and naming issues
Jun 16, 2023
fa88aab
made corrections to the comments and naming issues
Jun 16, 2023
b7a3fba
ok
Jul 10, 2023
9168a1b
ok
Jul 10, 2023
648f747
ok
timothijoe Jul 10, 2023
4693272
ok
timothijoe Jul 11, 2023
7000eed
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Aug 2, 2023
11b4b7b
polish(pu): polish game_2048_env
puyuan1996 Aug 2, 2023
6bd0310
polish(pu): polish chance encoder
puyuan1996 Aug 2, 2023
a77a5bf
Merge branch 'explicit_chance_branch' of https://github.com/timothijo…
puyuan1996 Aug 3, 2023
f85aec3
fix(pu): fix chance encoder related loss
puyuan1996 Aug 3, 2023
b22dea7
sync code
puyuan1996 Aug 4, 2023
860fda1
polish(pu): polish 2048 env, add env save_render_gif method, add 2048…
puyuan1996 Aug 8, 2023
6e13727
feature(pu): add stochastic muzero eval config
puyuan1996 Aug 9, 2023
6f519b6
polish(pu): polish 2048 save_replay method
puyuan1996 Aug 9, 2023
e5f6b08
feature(pu): add num_of_possible_chance_tile option in 2048 env
puyuan1996 Aug 9, 2023
7d6f4f1
polish(pu): delete collector filed in create config, move eval_config…
puyuan1996 Aug 18, 2023
1f82928
sync code
puyuan1996 Aug 19, 2023
3119258
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Aug 23, 2023
3a00a35
polish(pu): polish 2048 rule_bot move method, polish 2048 env, polish…
puyuan1996 Sep 5, 2023
89ab22a
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Sep 5, 2023
02046f4
feature(pu): add stochastic_muzero_model_mlp
puyuan1996 Sep 5, 2023
1cbce65
polish(pu): polish stochastic muzero configs
puyuan1996 Sep 5, 2023
b6e9006
feature(pu): add analyze utlis for chance distribution
puyuan1996 Sep 5, 2023
f4556ce
polish(pu): delete model_path personal info
puyuan1996 Sep 5, 2023
3b7bcb0
polish(pu): add TestVisualizationFunctions, polish stochastic muzero …
puyuan1996 Sep 10, 2023
1258be5
fix(pu): fix test_game_segment.py
puyuan1996 Sep 10, 2023
9e5b3d8
polish(pu): polish comments, abstract a get_target_obs_index_in_step_…
puyuan1996 Sep 12, 2023
7ea632f
polish(pu): use _get_target_obs_index_in_step_k in all policy, rename…
puyuan1996 Sep 12, 2023
4 changes: 2 additions & 2 deletions lzero/entry/eval_muzero.py
@@ -38,8 +38,8 @@ def eval_muzero(
- policy (:obj:`Policy`): Converged policy.
"""
cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'"
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
"LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
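For orientation, here is a minimal, self-contained sketch of a config pair that satisfies the widened check above; the EasyDict layout and all field values are illustrative assumptions, while the algorithm list and the CUDA device selection mirror this hunk.

import torch
from easydict import EasyDict

# Illustrative create_cfg/cfg pair; only policy.type is constrained by the assert above.
create_cfg = EasyDict(dict(policy=dict(type='stochastic_muzero')))
cfg = EasyDict(dict(policy=dict(cuda=True, device='cpu')))

assert create_cfg.policy.type in [
    'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'
]
# Device selection; the if-branch appears at the end of the hunk, the else is the natural fallback.
if cfg.policy.cuda and torch.cuda.is_available():
    cfg.policy.device = 'cuda'
else:
    cfg.policy.device = 'cpu'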
4 changes: 3 additions & 1 deletion lzero/entry/train_muzero.py
@@ -47,7 +47,7 @@ def train_muzero(
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'], \
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"

if create_cfg.policy.type == 'muzero':
@@ -58,6 +58,8 @@
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gumbel_muzero':
from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'stochastic_muzero':
from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
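The branch ladder above could equally be written as a lookup table. A sketch under the assumption that all five buffer classes are re-exported from lzero.mcts (the hunk shows this explicitly only for the sampled, gumbel, and stochastic variants):

from lzero.mcts import (
    MuZeroGameBuffer,                 # assumed export; its branch is outside the shown context
    EfficientZeroGameBuffer,          # assumed export
    SampledEfficientZeroGameBuffer,
    GumbelMuZeroGameBuffer,
    StochasticMuZeroGameBuffer,
)

# Policy type -> replay buffer class, mirroring the if/elif chain above.
GAME_BUFFER_BY_POLICY = {
    'muzero': MuZeroGameBuffer,
    'efficientzero': EfficientZeroGameBuffer,
    'sampled_efficientzero': SampledEfficientZeroGameBuffer,
    'gumbel_muzero': GumbelMuZeroGameBuffer,
    'stochastic_muzero': StochasticMuZeroGameBuffer,
}

policy_type = 'stochastic_muzero'     # in train_muzero this is create_cfg.policy.type
GameBuffer = GAME_BUFFER_BY_POLICY[policy_type]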
1 change: 1 addition & 0 deletions lzero/mcts/buffer/__init__.py
@@ -2,3 +2,4 @@
from .game_buffer_efficientzero import EfficientZeroGameBuffer
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
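With this export in place, the new buffer is reachable both by direct import and, via the decorator in the file below, through DI-engine's BUFFER_REGISTRY. A brief sketch (the assert is only a sanity check, not code from this PR):

from ding.utils import BUFFER_REGISTRY
from lzero.mcts.buffer import StochasticMuZeroGameBuffer  # importing the package also triggers registration

# The class registers itself under this key in game_buffer_stochastic_muzero.py.
buffer_cls = BUFFER_REGISTRY.get('game_buffer_stochastic_muzero')
assert buffer_cls is StochasticMuZeroGameBuffer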
171 changes: 171 additions & 0 deletions lzero/mcts/buffer/game_buffer_stochastic_muzero.py
@@ -0,0 +1,171 @@
from typing import Any, Tuple, List

import numpy as np
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.utils import prepare_observation
from .game_buffer_muzero import MuZeroGameBuffer


@BUFFER_REGISTRY.register('game_buffer_stochastic_muzero')
class StochasticMuZeroGameBuffer(MuZeroGameBuffer):
"""
Overview:
The specific game buffer for Stochastic MuZero policy.
"""

def __init__(self, cfg: dict):
super().__init__(cfg)
"""
Overview:
Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
in the default configuration, the user-provided value will override the default configuration. Otherwise,
the default configuration will be used.
"""
default_config = self.default_config()
default_config.update(cfg)
self._cfg = default_config
assert self._cfg.env_type in ['not_board_games', 'board_games']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
self._beta = self._cfg.priority_prob_beta

self.keep_ratio = 1
self.model_update_interval = 10
self.num_of_collected_episodes = 0
self.base_idx = 0
self.clear_time = 0

self.game_segment_buffer = []
self.game_pos_priorities = []
self.game_segment_game_pos_look_up = []

def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
"""
Overview:
first sample orig_data through ``_sample_orig_data()``,
then prepare the context of a batch:
reward_value_context: the context of reanalyzed value targets
policy_re_context: the context of reanalyzed policy targets
policy_non_re_context: the context of non-reanalyzed policy targets
current_batch: the inputs of batch
Arguments:
- batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
- reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
Returns:
- context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
"""
# obtain the batch context from replay buffer
orig_data = self._sample_orig_data(batch_size)
game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
batch_size = len(batch_index_list)
obs_list, action_list, mask_list = [], [], []
if self._cfg.use_ture_chance_label_in_chance_encoder:
chance_list = []
# prepare the inputs of a batch
for i in range(batch_size):
game = game_segment_list[i]
pos_in_game_segment = pos_in_game_segment_list[i]

actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
if self._cfg.use_ture_chance_label_in_chance_encoder:
chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
# add mask for invalid actions (out of trajectory)
mask_tmp = [1. for i in range(len(actions_tmp))]
mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))]

# pad random action
actions_tmp += [
np.random.randint(0, game.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
if self._cfg.use_ture_chance_label_in_chance_encoder:
chances_tmp += [
np.random.randint(0, game.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(chances_tmp))
]
# obtain the input observations
# pad if length of obs in game_segment is less than stack+num_unroll_steps
# e.g. stack+num_unroll_steps 4+5
obs_list.append(
game_segment_list[i].get_unroll_obs(
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
)
)
action_list.append(actions_tmp)
mask_list.append(mask_tmp)
if self._cfg.use_ture_chance_label_in_chance_encoder:
chance_list.append(chances_tmp)

# formalize the input observations
obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

# formalize the inputs of a batch
if self._cfg.use_ture_chance_label_in_chance_encoder:
current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list,
chance_list]
else:
current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])

total_transitions = self.get_num_of_transitions()

# obtain the context of value targets
reward_value_context = self._prepare_reward_value_context(
batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
)
"""
only reanalyze recent reanalyze_ratio (e.g. 50%) data
if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
"""
reanalyze_num = int(batch_size * reanalyze_ratio)
# reanalyzed policy
if reanalyze_num > 0:
# obtain the context of reanalyzed policy targets
policy_re_context = self._prepare_policy_reanalyzed_context(
batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
pos_in_game_segment_list[:reanalyze_num]
)
else:
policy_re_context = None

# non reanalyzed policy
if reanalyze_num < batch_size:
# obtain the context of non-reanalyzed policy targets
policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
pos_in_game_segment_list[reanalyze_num:]
)
else:
policy_non_re_context = None

context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
return context

def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
"""
Overview:
Update the priority of training data.
Arguments:
- train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority.
- batch_priorities (:obj:`batch_priorities`): priorities to update to.
NOTE:
train_data = [current_batch, target_batch]
if self._cfg.use_ture_chance_label_in_chance_encoder:
obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch
else:
obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch

"""
indices = train_data[0][3]
metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities}
# only update the priorities for data still in replay buffer
for i in range(len(indices)):
if metas['make_time'][i] > self.clear_time:
idx, prio = indices[i], metas['batch_priorities'][i]
self.game_pos_priorities[idx] = prio
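A construction sketch based only on the keys the constructor above reads; every other field is filled in by default_config(), and the concrete values here are illustrative rather than the settings shipped with the 2048 configs.

from easydict import EasyDict
from lzero.mcts import StochasticMuZeroGameBuffer

cfg = EasyDict(dict(
    env_type='not_board_games',   # must be 'not_board_games' or 'board_games'
    replay_buffer_size=int(1e6),
    batch_size=256,
    priority_prob_alpha=0.6,      # prioritized-replay exponent alpha
    priority_prob_beta=0.4,       # importance-sampling exponent beta
))
replay_buffer = StochasticMuZeroGameBuffer(cfg)

# _make_batch() returns (reward_value_context, policy_re_context,
# policy_non_re_context, current_batch); when use_ture_chance_label_in_chance_encoder
# is enabled, current_batch additionally carries the per-step chance labels.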
57 changes: 37 additions & 20 deletions lzero/mcts/buffer/game_segment.py
@@ -41,11 +41,17 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea
"""
self.action_space = action_space
self.game_segment_length = game_segment_length
self.config = config

self.num_unroll_steps = config.num_unroll_steps
self.td_steps = config.td_steps
self.frame_stack_num = config.model.frame_stack_num
self.discount_factor = config.discount_factor
self.action_space_size = config.model.action_space_size
self.gray_scale = config.gray_scale
self.transform2string = config.transform2string
self.sampled_algo = config.sampled_algo
self.gumbel_algo = config.gumbel_algo
self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
# for vector obs input, e.g. classical control and box2d environments
self.zero_obs_shape = config.model.observation_shape
@@ -71,8 +77,11 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea

self.improved_policy_probs = []

if self.config.sampled_algo:
if self.sampled_algo:
self.root_sampled_actions = []
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []


def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
"""
@@ -89,8 +98,8 @@ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool
if pad_len > 0:
pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)])
stacked_obs = np.concatenate((stacked_obs, pad_frames))
if self.config.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.config.gray_scale) for obs in stacked_obs]
if self.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
return stacked_obs

def zero_obs(self) -> List:
@@ -114,12 +123,10 @@ def get_obs(self) -> List:
assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
timestep_obs, timestep_reward
)
# TODO:
timestep = timestep_obs
timestep = timestep_reward
stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
if self.config.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.config.gray_scale) for obs in stacked_obs]
if self.transform2string:
stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
return stacked_obs

def append(
@@ -128,7 +135,8 @@ def append(
obs: np.ndarray,
reward: np.ndarray,
action_mask: np.ndarray = None,
to_play: int = -1
to_play: int = -1,
chance: int = 0,
) -> None:
"""
Overview:
@@ -140,10 +148,12 @@ def append(

self.action_mask_segment.append(action_mask)
self.to_play_segment.append(to_play)
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment.append(chance)

def pad_over(
self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
next_segment_child_visits: List, next_segment_improved_policy: List = None
next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
) -> None:
"""
Overview:
@@ -158,15 +168,15 @@ def pad_over(
- next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment
- next_segment_improved_policy (:obj:`list`): root children select policy of MCTS from the next game_segment (Only used in Gumbel MuZero)
"""
assert len(next_segment_observations) <= self.config.num_unroll_steps
assert len(next_segment_child_visits) <= self.config.num_unroll_steps
assert len(next_segment_root_values) <= self.config.num_unroll_steps + self.config.td_steps
assert len(next_segment_rewards) <= self.config.num_unroll_steps + self.config.td_steps - 1
assert len(next_segment_observations) <= self.num_unroll_steps
assert len(next_segment_child_visits) <= self.num_unroll_steps
assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
# ==============================================================
# The core difference between GumbelMuZero and MuZero
# ==============================================================
if self.config.gumbel_algo:
assert len(next_segment_improved_policy) <= self.config.num_unroll_steps + self.config.td_steps
if self.gumbel_algo:
assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps

# NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
for observation in next_segment_observations:
@@ -181,9 +191,12 @@ def pad_over(
for child_visits in next_segment_child_visits:
self.child_visit_segment.append(child_visits)

if self.config.gumbel_algo:
if self.gumbel_algo:
for improved_policy in next_segment_improved_policy:
self.improved_policy_probs.append(improved_policy)
if self.use_ture_chance_label_in_chance_encoder:
for chances in next_chances:
self.chance_segment.append(chances)

def get_targets(self, timestep: int) -> Tuple:
"""
@@ -203,10 +216,10 @@ def store_search_stats(
if idx is None:
self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts])
self.root_value_segment.append(root_value)
if self.config.sampled_algo:
if self.sampled_algo:
self.root_sampled_actions.append(root_sampled_actions)
# store the improved policy in Gumbel Muzero: \pi'=softmax(logits + \sigma(CompletedQ))
if self.config.gumbel_algo:
if self.gumbel_algo:
self.improved_policy_probs.append(improved_policy)
else:
self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts]
@@ -261,6 +274,8 @@ def game_segment_to_array(self) -> None:

self.action_mask_segment = np.array(self.action_mask_segment)
self.to_play_segment = np.array(self.to_play_segment)
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = np.array(self.chance_segment)

def reset(self, init_observations: np.ndarray) -> None:
"""
@@ -279,6 +294,8 @@ def reset(self, init_observations: np.ndarray) -> None:

self.action_mask_segment = []
self.to_play_segment = []
if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []

assert len(init_observations) == self.frame_stack_num

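A collection-side sketch of the new chance plumbing; the variable names below are assumptions, not code from this PR. With use_ture_chance_label_in_chance_encoder enabled in the config, the true chance outcome of each step (for 2048, which tile spawned and where) is passed to append() and stored in chance_segment, and pad_over() receives the next segment's chances so padding stays aligned with the other targets.

from lzero.mcts.buffer.game_segment import GameSegment

# Assumed setup: `cfg` is the policy config with
# cfg.use_ture_chance_label_in_chance_encoder = True, and `env` is the 2048 env.
segment = GameSegment(env.action_space, game_segment_length=200, config=cfg)
segment.reset(init_observations=initial_stacked_obs)  # exactly frame_stack_num frames

# Inside the collect loop (one step shown); the leading arguments are passed
# positionally to match the signature above, whose first parameter is not shown here.
segment.append(
    action,
    next_obs,
    reward,
    action_mask=action_mask,
    to_play=-1,
    chance=chance_index,   # index of the sampled chance outcome for this step
)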