feature(nyz): add basic task pipeline #24

Open · wants to merge 5 commits into base: main
2 changes: 1 addition & 1 deletion lzero/entry/train_muzero.py
@@ -150,7 +150,7 @@ def train_muzero(
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
replay_buffer.update_priority(train_data, log_vars[0]['td_error_priority'])

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
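For context, the hunk above only changes which key is used to read per-sample priorities out of the learner's logged variables before feeding them back into the replay buffer. Below is a minimal, self-contained sketch of that feedback loop: the `ToyReplayBuffer` class, the dict layout of `train_data`, and the concrete numbers are invented for illustration; only the `use_priority` guard and the `update_priority(train_data, log_vars[0]['td_error_priority'])` call shape follow the diff.

```python
import numpy as np


class ToyReplayBuffer:
    """Toy stand-in for the game buffer; it only illustrates the call shape."""

    def __init__(self, capacity: int = 100) -> None:
        self.priorities = np.ones(capacity, dtype=np.float32)

    def update_priority(self, train_data: dict, batch_priorities: np.ndarray) -> None:
        # The toy train_data just carries the sampled transition indices;
        # the real buffer keeps them inside its own batch structures.
        indices = train_data['batch_index_list']
        self.priorities[indices] = np.abs(batch_priorities) + 1e-6


# Hypothetical learner output: one dict of logged variables per training call.
log_vars = [{'td_error_priority': np.array([0.5, 1.2, 0.1], dtype=np.float32)}]
train_data = {'batch_index_list': np.array([3, 7, 42])}

replay_buffer = ToyReplayBuffer()
use_priority = True
if use_priority:
    replay_buffer.update_priority(train_data, log_vars[0]['td_error_priority'])
print(replay_buffer.priorities[[3, 7, 42]])  # updated to |priority| + 1e-6
```

The small `1e-6` floor in the toy mirrors the numerical-stability epsilon the buffer itself uses when turning priorities into sampling probabilities.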
2 changes: 1 addition & 1 deletion lzero/entry/train_muzero_with_gym_env.py
@@ -156,7 +156,7 @@ def train_muzero_with_gym_env(
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
replay_buffer.update_priority(train_data, log_vars[0]['priority'])

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
204 changes: 102 additions & 102 deletions lzero/mcts/buffer/game_buffer.py
@@ -51,12 +51,11 @@ def __init__(self, cfg: dict):
default_config = self.default_config()
default_config.update(cfg)
self._cfg = default_config
self._cfg = cfg
assert self._cfg.env_type in ['not_board_games', 'board_games']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
self._beta = self._cfg.priority_prob_beta
self.alpha = self._cfg.priority_prob_alpha
self.beta = self._cfg.priority_prob_beta

self.game_segment_buffer = []
self.game_pos_priorities = []
@@ -80,6 +79,7 @@ def sample(
Returns:
- train_data (:obj:`List`): List of train data, including current_batch and target_batch.
"""
raise NotImplementedError

@abstractmethod
def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]:
@@ -96,98 +96,7 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]:
Returns:
- context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
"""
pass

def _sample_orig_data(self, batch_size: int) -> Tuple:
"""
Overview:
sample orig_data that contains:
game_segment_list: a list of game segments
pos_in_game_segment_list: transition index in game (relative index)
batch_index_list: the index of start transition of sampled minibatch in replay buffer
weights_list: the weight concerning the priority
make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
Arguments:
- batch_size (:obj:`int`): batch size
- beta: float the parameter in PER for calculating the priority
"""
assert self._beta > 0
num_of_transitions = self.get_num_of_transitions()
if self._cfg.use_priority is False:
self.game_pos_priorities = np.ones_like(self.game_pos_priorities)

# +1e-6 for numerical stability
probs = self.game_pos_priorities ** self._alpha + 1e-6
probs /= probs.sum()

# sample according to transition index
# TODO(pu): replace=True
batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False)

if self._cfg.reanalyze_outdated is True:
# NOTE: used in reanalyze part
batch_index_list.sort()

weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta)
weights_list /= weights_list.max()

game_segment_list = []
pos_in_game_segment_list = []

for idx in batch_index_list:
game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx]
game_segment_idx -= self.base_idx
game_segment = self.game_segment_buffer[game_segment_idx]

game_segment_list.append(game_segment)
pos_in_game_segment_list.append(pos_in_game_segment)

make_time = [time.time() for _ in range(len(batch_index_list))]

orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)
return orig_data

def _preprocess_to_play_and_action_mask(
self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
):
"""
Overview:
prepare the to_play and action_mask for the target obs in ``value_obs_list``
- to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}
- action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}
"""
to_play = []
for bs in range(game_segment_batch_size):
to_play_tmp = list(
to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
self._cfg.num_unroll_steps + 1]
)
if len(to_play_tmp) < self._cfg.num_unroll_steps + 1:
# NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1
to_play_tmp += [-1 for _ in range(self._cfg.num_unroll_steps + 1 - len(to_play_tmp))]
to_play.append(to_play_tmp)
to_play = sum(to_play, [])

if self._cfg.model.continuous_action_space is True:
# when the action space of the environment is continuous, action_mask[:] is None.
return to_play, None

action_mask = []
for bs in range(game_segment_batch_size):
action_mask_tmp = list(
action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
self._cfg.num_unroll_steps + 1]
)
if len(action_mask_tmp) < self._cfg.num_unroll_steps + 1:
action_mask_tmp += [
list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
for _ in range(self._cfg.num_unroll_steps + 1 - len(action_mask_tmp))
]
action_mask.append(action_mask_tmp)
action_mask = to_list(action_mask)
action_mask = sum(action_mask, [])

return to_play, action_mask
raise NotImplementedError

@abstractmethod
def _prepare_reward_value_context(
@@ -206,7 +115,7 @@ def _prepare_reward_value_context(
- reward_value_context (:obj:`list`): value_obs_lst, value_mask, state_index_lst, rewards_lst, game_segment_lens,
td_steps_lst, action_mask_segment, to_play_segment
"""
pass
raise NotImplementedError

@abstractmethod
def _prepare_policy_non_reanalyzed_context(
@@ -222,7 +131,7 @@ def _prepare_policy_non_reanalyzed_context(
Returns:
- policy_non_re_context (:obj:`list`): state_index_lst, child_visits, game_segment_lens, action_mask_segment, to_play_segment
"""
pass
raise NotImplementedError

@abstractmethod
def _prepare_policy_reanalyzed_context(
@@ -239,7 +148,7 @@ def _prepare_policy_reanalyzed_context(
- policy_re_context (:obj:`list`): policy_obs_lst, policy_mask, state_index_lst, indices,
child_visits, game_segment_lens, action_mask_segment, to_play_segment
"""
pass
raise NotImplementedError

@abstractmethod
def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]:
@@ -253,7 +162,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
- batch_value_prefixs (:obj:`np.ndarray`): batch of value prefix
- batch_target_values (:obj:`np.ndarray`): batch of value estimation
"""
pass
raise NotImplementedError

@abstractmethod
def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
@@ -265,7 +174,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
Returns:
- batch_target_policies_re
"""
pass
raise NotImplementedError

@abstractmethod
def _compute_target_policy_non_reanalyzed(
@@ -284,7 +193,7 @@ def _compute_target_policy_non_reanalyzed(
Returns:
- batch_target_policies_non_re
"""
pass
raise NotImplementedError

@abstractmethod
def update_priority(
@@ -297,7 +206,98 @@ def update_priority(
- train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority.
- batch_priorities (:obj:`batch_priorities`): priorities to update to.
"""
pass
raise NotImplementedError

def _sample_orig_data(self, batch_size: int) -> Tuple:
"""
Overview:
sample orig_data that contains:
game_segment_list: a list of game segments
pos_in_game_segment_list: transition index in game (relative index)
batch_index_list: the index of start transition of sampled minibatch in replay buffer
weights_list: the weight concerning the priority
make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
Arguments:
- batch_size (:obj:`int`): batch size
- beta: float the parameter in PER for calculating the priority
"""
assert self.beta > 0, self.beta
num_of_transitions = self.get_num_of_transitions()
if self._cfg.use_priority is False:
self.game_pos_priorities = np.ones_like(self.game_pos_priorities)

# +1e-6 for numerical stability
probs = self.game_pos_priorities ** self.alpha + 1e-6
probs /= probs.sum()

# sample according to transition index
# TODO(pu): replace=True
batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False)

if self._cfg.reanalyze_outdated is True:
# NOTE: used in reanalyze part
batch_index_list.sort()

weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self.beta)
weights_list /= weights_list.max()

game_segment_list = []
pos_in_game_segment_list = []

for idx in batch_index_list:
game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx]
game_segment_idx -= self.base_idx
game_segment = self.game_segment_buffer[game_segment_idx]

game_segment_list.append(game_segment)
pos_in_game_segment_list.append(pos_in_game_segment)

make_time = [time.time() for _ in range(len(batch_index_list))]

orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)
return orig_data

def _preprocess_to_play_and_action_mask(
self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
):
"""
Overview:
prepare the to_play and action_mask for the target obs in ``value_obs_list``
- to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}
- action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}
"""
to_play = []
for bs in range(game_segment_batch_size):
to_play_tmp = list(
to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
self._cfg.num_unroll_steps + 1]
)
if len(to_play_tmp) < self._cfg.num_unroll_steps + 1:
# NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1
to_play_tmp += [-1 for _ in range(self._cfg.num_unroll_steps + 1 - len(to_play_tmp))]
to_play.append(to_play_tmp)
to_play = sum(to_play, [])

if self._cfg.model.continuous_action_space is True:
# when the action space of the environment is continuous, action_mask[:] is None.
return to_play, None

action_mask = []
for bs in range(game_segment_batch_size):
action_mask_tmp = list(
action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
self._cfg.num_unroll_steps + 1]
)
if len(action_mask_tmp) < self._cfg.num_unroll_steps + 1:
action_mask_tmp += [
list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
for _ in range(self._cfg.num_unroll_steps + 1 - len(action_mask_tmp))
]
action_mask.append(action_mask_tmp)
action_mask = to_list(action_mask)
action_mask = sum(action_mask, [])

return to_play, action_mask

def push_game_segments(self, data_and_meta: Any) -> None:
"""
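The `_sample_orig_data` method shown above implements standard proportional prioritized experience replay. The standalone sketch below reproduces just the sampling math visible in the diff (the `alpha` priority exponent, the `1e-6` stabilizer, sampling without replacement, and max-normalized importance-sampling weights with exponent `beta`); the priority values, batch size, and exponents are made-up example numbers, not the buffer's defaults.

```python
import numpy as np

# Made-up per-transition priorities; in the buffer these are game_pos_priorities.
priorities = np.array([0.5, 2.0, 0.1, 1.0, 3.0], dtype=np.float64)
alpha, beta = 0.6, 0.4  # stand-ins for priority_prob_alpha / priority_prob_beta
batch_size = 3
num_of_transitions = len(priorities)

# Sampling probabilities: p_i ** alpha, with +1e-6 for numerical stability.
probs = priorities ** alpha + 1e-6
probs /= probs.sum()

# Sample transition indices without replacement, as in _sample_orig_data.
batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False)
batch_index_list.sort()  # the reanalyze_outdated branch sorts the sampled indices

# Importance-sampling weights: (N * P(i)) ** (-beta), normalized by the largest weight.
weights_list = (num_of_transitions * probs[batch_index_list]) ** (-beta)
weights_list /= weights_list.max()

print(batch_index_list, weights_list)
```

Transitions with larger priority are drawn more often, and the `(N * P(i)) ** (-beta)` correction down-weights exactly those over-sampled transitions so the resulting gradient estimate stays approximately unbiased.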
29 changes: 14 additions & 15 deletions lzero/mcts/buffer/game_buffer_efficientzero.py
@@ -1,4 +1,5 @@
from typing import Any, List
from typing import Any, List, TYPE_CHECKING
from easydict import EasyDict

import numpy as np
import torch
@@ -10,6 +11,9 @@
from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
from .game_buffer_muzero import MuZeroGameBuffer

if TYPE_CHECKING:
from ding.policy import Policy


@BUFFER_REGISTRY.register('game_buffer_efficientzero')
class EfficientZeroGameBuffer(MuZeroGameBuffer):
@@ -18,22 +22,17 @@ class EfficientZeroGameBuffer(MuZeroGameBuffer):
The specific game buffer for EfficientZero policy.
"""

def __init__(self, cfg: dict):
super().__init__(cfg)
def __init__(self, cfg: EasyDict) -> None:
"""
Overview:
Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
in the default configuration, the user-provided value will override the default configuration. Otherwise,
the default configuration will be used.
"""
default_config = self.default_config()
default_config.update(cfg)
self._cfg = default_config
assert self._cfg.env_type in ['not_board_games', 'board_games']
super().__init__(cfg)
assert self._cfg.env_type in ['not_board_games', 'board_games'], self._cfg.env_type
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
self._beta = self._cfg.priority_prob_beta

self.game_segment_buffer = []
self.game_pos_priorities = []
@@ -44,15 +43,16 @@ def __init__(self, cfg: dict):
self.base_idx = 0
self.clear_time = 0

def sample(self, batch_size: int, policy: Any) -> List[Any]:
def sample(self, batch_size: int, policy: 'Policy') -> List[Any]:
"""
Overview:
sample data from ``GameBuffer`` and prepare the current and target batch for training
Sample a mini-batch of data for training, mainly including random sampling and preparing the current and \
target batch with/without reanalyzing operation mentioned in MuZero.
Arguments:
- batch_size (:obj:`int`): batch size
- policy (:obj:`torch.tensor`): model of policy
- batch_size (:obj:`int`): The number of samples in a mini-batch.
- policy (:obj:`Policy`): The policy instance used to execute reanalyzing operation.
Returns:
- train_data (:obj:`List`): List of train data
- train_data (:obj:`List`): List of sampled training data.
"""
policy._target_model.to(self._cfg.device)
policy._target_model.eval()
@@ -180,7 +180,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
to_play, action_mask = self._preprocess_to_play_and_action_mask(
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
)

legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]

# ==============================================================
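The new `TYPE_CHECKING` import in this file is the standard way to expose `Policy` to type annotations such as `sample(self, batch_size, policy: 'Policy')` without importing it at runtime, avoiding both the import cost and any potential import cycle. A generic, self-contained illustration of the pattern is below; the module and names are invented and do not refer to ding or LZero code.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static analyzers such as mypy or pyright, never at
    # runtime, so a heavy or circular import cannot break program startup.
    from some_heavy_module import HeavyPolicy  # hypothetical module


def evaluate(policy: 'HeavyPolicy') -> str:
    # The quoted annotation is a forward reference; the interpreter never
    # needs to resolve it, only type checkers do.
    return type(policy).__name__
```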
7 changes: 1 addition & 6 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -22,21 +22,16 @@ class MuZeroGameBuffer(GameBuffer):
"""

def __init__(self, cfg: dict):
super().__init__(cfg)
"""
Overview:
Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
in the default configuration, the user-provided value will override the default configuration. Otherwise,
the default configuration will be used.
"""
default_config = self.default_config()
default_config.update(cfg)
self._cfg = default_config
super().__init__(cfg)
assert self._cfg.env_type in ['not_board_games', 'board_games']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
self._beta = self._cfg.priority_prob_beta

self.keep_ratio = 1
self.model_update_interval = 10
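The docstrings in these `__init__` methods describe the default-config mechanism: if a user-provided cfg key matches a key in `default_config()`, the user value overrides the default, otherwise the default is kept. The sketch below shows exactly that merge behaviour using `EasyDict`, as the buffers do; the keys and values are invented for illustration and are not the buffers' real defaults.

```python
from easydict import EasyDict

# Stand-in for GameBuffer.default_config(): a fresh copy of class-level defaults.
default_config = EasyDict(replay_buffer_size=10000, batch_size=256, use_priority=True)

# User cfg: keys present here override the defaults, missing keys fall back.
user_cfg = EasyDict(batch_size=64)

default_config.update(user_cfg)
cfg = default_config

print(cfg.batch_size)          # 64, taken from the user cfg
print(cfg.replay_buffer_size)  # 10000, falls back to the default
```

The `update` call is the same `default_config.update(cfg)` pattern visible in the diff above.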