feature(rjy): add mamujoco env and related configs #153

Open · wants to merge 8 commits into base: main

18 changes: 15 additions & 3 deletions lzero/mcts/buffer/game_buffer.py
@@ -57,6 +57,9 @@ def __init__(self, cfg: dict):
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
self._beta = self._cfg.priority_prob_beta
self._multi_agent = self._cfg.model.get('multi_agent', False)
if self._multi_agent:
self._num_agents = self._cfg.model.get('agent_num', 1)

self.game_segment_buffer = []
self.game_pos_priorities = []
@@ -344,9 +347,18 @@ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None:
)
else:
assert len(data) == len(meta['priorities']), "priorities should be of the same length as the game steps"
priorities = meta['priorities'].copy().reshape(-1)
priorities[valid_len:len(data)] = 0.
self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities))
if self._multi_agent:
priorities = meta['priorities'].copy()
priorities[valid_len:len(data)] = np.zeros_like(priorities[0])
if len(self.game_pos_priorities) == 0:
self.game_pos_priorities = priorities
else:
self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities))

else:
priorities = meta['priorities'].copy().reshape(-1)
priorities[valid_len:len(data)] = 0.
self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities))

self.game_segment_buffer.append(data)
self.game_segment_game_pos_look_up += [
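For context, a minimal sketch of the two priority-padding branches above, assuming multi-agent priorities arrive as an array of shape `(game_len, num_agents)` (the shapes here are illustrative assumptions, not taken from this PR):

```python
import numpy as np

def pad_and_append(game_pos_priorities, priorities, valid_len, game_len, multi_agent):
    # Illustrative sketch of the branching above; array shapes are assumptions.
    priorities = priorities.copy()
    if multi_agent:
        # Per-step priorities carry one value per agent, e.g. shape (game_len, num_agents);
        # the padded tail is zeroed row-wise.
        priorities[valid_len:game_len] = np.zeros_like(priorities[0])
        if len(game_pos_priorities) == 0:
            return priorities
        return np.concatenate((game_pos_priorities, priorities))
    # Single-agent: flatten to 1-D and zero the padded tail.
    priorities = priorities.reshape(-1)
    priorities[valid_len:game_len] = 0.
    return np.concatenate((game_pos_priorities, priorities))

# Usage: a 5-step segment with 2 agents, of which only the first 3 steps are valid.
buffer_priorities = np.zeros((0, 2))
segment_priorities = np.ones((5, 2))
buffer_priorities = pad_and_append(buffer_priorities, segment_priorities, valid_len=3, game_len=5, multi_agent=True)
print(buffer_priorities.shape)  # (5, 2); the last two rows are zeros
```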
8 changes: 6 additions & 2 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -670,12 +670,16 @@ def _compute_target_policy_non_reanalyzed(
else:
# NOTE: for the invalid padding target policy, 0 is used to make sure the corresponding cross_entropy_loss = 0
policy_mask.append(0)
target_policies.append([0 for _ in range(policy_shape)])
if self._multi_agent:
target_policies.append([np.zeros_like(child_visit[0][0])] * self._cfg.model.agent_num)
else:
target_policies.append([0 for _ in range(policy_shape)])

policy_index += 1

batch_target_policies_non_re.append(target_policies)
batch_target_policies_non_re = np.asarray(batch_target_policies_non_re)
if not self._multi_agent:
batch_target_policies_non_re = np.asarray(batch_target_policies_non_re)
return batch_target_policies_non_re

def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
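A small standalone sketch of the padded non-reanalyzed target policy in each branch above (the agent count and action-space size are made-up examples):

```python
import numpy as np

# Assumed example: 2 agents, each with a 4-dim normalized visit-count distribution per root.
agent_num, policy_shape = 2, 4
child_visit_step = [np.full(policy_shape, 0.25, dtype=np.float32) for _ in range(agent_num)]

# Valid step: the visit-count distribution itself is used as the target policy.
valid_target = child_visit_step

# Invalid (padded) step, multi-agent branch: one all-zero vector per agent,
# so the corresponding cross-entropy loss term is zero.
padded_multi_agent = [np.zeros_like(child_visit_step[0])] * agent_num

# Invalid step, single-agent branch: a flat list of zeros of length policy_shape.
padded_single_agent = [0 for _ in range(policy_shape)]

print(padded_multi_agent[0].shape, len(padded_single_agent))  # (4,) 4
```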
90 changes: 60 additions & 30 deletions lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
@@ -3,6 +3,8 @@
import numpy as np
import torch
from ding.utils import BUFFER_REGISTRY
from ding.utils.data import default_collate, default_decollate
from ding.torch_utils import to_tensor, to_device, to_dtype, to_ndarray

from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree
from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree
@@ -140,7 +142,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
# sampled related core code
# ==============================================================
actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
self._cfg.num_unroll_steps]
if not isinstance(actions_tmp, list):
actions_tmp = actions_tmp.tolist()

# NOTE: self._cfg.num_unroll_steps + 1
root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment +
@@ -152,14 +156,25 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:

# pad random action
if self._cfg.model.continuous_action_space:
actions_tmp += [
np.random.randn(self._cfg.model.action_space_size)
if self._multi_agent:
actions_tmp += [
np.random.randn(self._cfg.model.agent_num, self._cfg.model.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
root_sampled_actions_tmp += [
root_sampled_actions_tmp += [
np.random.rand(self._cfg.model.agent_num, self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size)
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
]
else:
actions_tmp += [
np.random.randn(self._cfg.model.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
root_sampled_actions_tmp += [
np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size)
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
]

else:
# generate random `padded actions_tmp`
actions_tmp += generate_random_actions_discrete(
@@ -192,7 +207,8 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
mask_list.append(mask_tmp)

# formalize the input observations
obs_list = prepare_observation(obs_list, self._cfg.model.model_type)
if not self._multi_agent:
obs_list = prepare_observation(obs_list, self._cfg.model.model_type)
# ==============================================================
# sampled related core code
# ==============================================================
@@ -202,7 +218,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
]

for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])
current_batch[i] = to_ndarray(current_batch[i])

total_transitions = self.get_num_of_transitions()

@@ -251,14 +267,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \
to_play_segment = reward_value_context # noqa

# transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
# transition_batch_size = game_segment_batch_size * (num_unroll_steps + 1)
transition_batch_size = len(value_obs_list)
game_segment_batch_size = len(pos_in_game_segment_list)

to_play, action_mask = self._preprocess_to_play_and_action_mask(
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
)
if self._cfg.model.continuous_action_space is True:
if self._cfg.model.continuous_action_space:
# when the action space of the environment is continuous, action_mask[:] is None.
action_mask = [
list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
@@ -272,15 +288,20 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A

batch_target_values, batch_value_prefixs = [], []
with torch.no_grad():
value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
if not self._multi_agent:
value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
# split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
network_output = []
for i in range(slices):
beg_index = self._cfg.mini_infer_size * i
end_index = self._cfg.mini_infer_size * (i + 1)
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()

if self._multi_agent:
m_obs = to_dtype(to_device(to_tensor(value_obs_list[beg_index:end_index]), self._cfg.device), torch.float)
m_obs = default_collate(m_obs)
m_obs = m_obs[0]
else:
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
# calculate the target value
m_output = model.initial_inference(m_obs)

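For readers unfamiliar with `default_collate`, a standalone sketch of what the multi-agent observation branch above does; `torch.utils.data.default_collate` is used here as a stand-in, and the dict layout of the observations is an assumption:

```python
import torch
from torch.utils.data import default_collate  # stand-in for ding.utils.data.default_collate

# Assumed layout: each transition's observation is a dict of per-agent / global arrays.
mini_infer_size = 4
value_obs_slice = [
    {'agent_state': torch.randn(2, 8), 'global_state': torch.randn(16)}
    for _ in range(mini_infer_size)
]

# Collating stacks every field along a new leading batch dimension,
# giving a single structure that can be fed to model.initial_inference.
m_obs = default_collate(value_obs_slice)
print(m_obs['agent_state'].shape)   # torch.Size([4, 2, 8])
print(m_obs['global_state'].shape)  # torch.Size([4, 16])
```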
@@ -335,8 +356,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
)
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
# do MCTS for a new policy with the recent target model
MCTSPtree.roots(self._cfg
).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
MCTSPtree.roots(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)

roots_values = roots.get_values()
value_list = np.array(roots_values)
@@ -355,25 +375,32 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
]
)
else:
value_list = value_list.reshape(-1) * (
np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
)
if self._multi_agent:
value_list = value_list.reshape(transition_batch_size, self._cfg.model.agent_num)
factor = np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
value_list = value_list * factor.reshape(transition_batch_size, 1).astype(np.float32)
else:
value_list = value_list.reshape(-1) * (
np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
)

value_list = value_list * np.array(value_mask)
value_list = value_list.tolist()
if self._multi_agent:
value_list = value_list * np.array(value_mask)[:, np.newaxis]
else:
value_list = value_list * np.array(value_mask)
value_list = value_list.tolist()

horizon_id, value_index = 0, 0
for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list,
pos_in_game_segment_list,
to_play_segment):
pos_in_game_segment_list,
to_play_segment):
target_values = []
target_value_prefixs = []

value_prefix = 0.0
base_index = state_index
for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
bootstrap_index = current_index + td_steps_list[value_index]
# for i, reward in enumerate(game.rewards[current_index:bootstrap_index]):
for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
# TODO(pu): for board_games, very important, to check
@@ -395,20 +422,23 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
target_values.append(value_list[value_index])
# Since the horizon is small and the discount_factor is close to 1.
# Compute the reward sum to approximate the value prefix for simplification
value_prefix += reward_list[current_index
] # * config.discount_factor ** (current_index - base_index)
value_prefix += reward_list[current_index]
target_value_prefixs.append(value_prefix)
else:
target_values.append(0)
target_value_prefixs.append(value_prefix)

if self._multi_agent:
target_values.append(np.zeros_like(value_list[0]))
target_value_prefixs.append(np.array([0,]))
else:
target_values.append(0)
target_value_prefixs.append(value_prefix)
value_index += 1

batch_value_prefixs.append(target_value_prefixs)
batch_target_values.append(target_values)

batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object)
batch_target_values = np.asarray(batch_target_values, dtype=object)
if not self._multi_agent:
batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=np.float32)
batch_target_values = np.asarray(batch_target_values, dtype=np.float32)

return batch_value_prefixs, batch_target_values

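As a sanity check on the multi-agent value-target shapes above, a toy example (batch size, agent count, and values are made up) of the per-agent discounting and masking:

```python
import numpy as np

transition_batch_size, agent_num = 3, 2
discount_factor = 0.997
td_steps_list = np.array([5, 5, 3])
value_mask = [1, 1, 0]  # the last transition is padding

# Root values come back flattened across agents -> reshape to (batch, agent).
value_list = np.arange(transition_batch_size * agent_num, dtype=np.float32)
value_list = value_list.reshape(transition_batch_size, agent_num)

# Discount each transition by gamma ** td_steps and broadcast over agents.
factor = np.array([discount_factor] * transition_batch_size) ** td_steps_list
value_list = value_list * factor.reshape(transition_batch_size, 1).astype(np.float32)

# Zero out padded transitions for every agent at once.
value_list = value_list * np.array(value_mask)[:, np.newaxis]
print(value_list.shape)  # (3, 2); the last row is all zeros
```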
@@ -557,8 +587,8 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
policy_index += 1

batch_target_policies_re.append(target_policies)

batch_target_policies_re = np.array(batch_target_policies_re)
if not self._multi_agent:
batch_target_policies_re = np.array(batch_target_policies_re)

return batch_target_policies_re, root_sampled_actions

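The buffer changes above are all gated by two model-config fields; a hypothetical config snippet (field values and the MaMuJoCo task are placeholders, not copied from this PR's config files) showing how the multi-agent path would be switched on:

```python
from easydict import EasyDict

# Hypothetical sketch: only the fields read by the buffer code above are shown.
policy_config = EasyDict(dict(
    model=dict(
        multi_agent=True,          # enables the multi-agent branches in the buffers
        agent_num=2,               # e.g. a 2-agent MaMuJoCo HalfCheetah split (placeholder)
        continuous_action_space=True,
        action_space_size=3,       # per-agent action dimension (placeholder)
        num_of_sampled_actions=20,
    ),
    num_unroll_steps=5,
    priority_prob_alpha=0.6,
    priority_prob_beta=0.4,
    mini_infer_size=256,
    discount_factor=0.997,
))

# The buffer reads these with .get(), so omitting them falls back to single-agent behavior:
print(policy_config.model.get('multi_agent', False), policy_config.model.get('agent_num', 1))
```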