diff --git a/lzero/agent/efficientzero.py b/lzero/agent/efficientzero.py index 421cea881..bd8e6ff7b 100644 --- a/lzero/agent/efficientzero.py +++ b/lzero/agent/efficientzero.py @@ -110,6 +110,9 @@ def __init__( elif self.cfg.policy.model.model_type == 'conv': from lzero.model.efficientzero_model import EfficientZeroModel model = EfficientZeroModel(**self.cfg.policy.model) + elif self.cfg.policy.model.model_type == 'mlp_md': + from lzero.model.efficientzero_model_md import EfficientZeroModelMD + model = EfficientZeroModelMD(**self.cfg.policy.model) else: raise NotImplementedError if self.cfg.policy.cuda and torch.cuda.is_available(): @@ -124,8 +127,8 @@ def __init__( self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) def train( - self, - step: int = int(1e7), + self, + step: int = int(1e7), ) -> TrainingReturn: """ Overview: @@ -356,8 +359,8 @@ def deploy( return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) def batch_evaluate( - self, - n_evaluator_episode: int = None, + self, + n_evaluator_episode: int = None, ) -> EvalReturn: """ Overview: diff --git a/lzero/agent/muzero.py b/lzero/agent/muzero.py index 55dda5d00..dfb691a69 100644 --- a/lzero/agent/muzero.py +++ b/lzero/agent/muzero.py @@ -110,6 +110,12 @@ def __init__( elif self.cfg.policy.model.model_type == 'conv': from lzero.model.muzero_model import MuZeroModel model = MuZeroModel(**self.cfg.policy.model) + elif self.cfg.policy.model.model_type == 'rgcn': + from lzero.model.muzero_model_gcn import MuZeroModelGCN + model = MuZeroModelGCN(**self.cfg.policy.model) + elif self.cfg.policy.model.model_type == 'mlp_md': + from lzero.model.muzero_model_md import MuZeroModelMD + model = MuZeroModelMD(**self.cfg.policy.model) else: raise NotImplementedError if self.cfg.policy.cuda and torch.cuda.is_available(): @@ -124,8 +130,8 @@ def __init__( self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) def train( - self, - step: int = int(1e7), + self, + step: int = int(1e7), ) -> TrainingReturn: """ Overview: @@ -356,8 +362,8 @@ def deploy( return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) def batch_evaluate( - self, - n_evaluator_episode: int = None, + self, + n_evaluator_episode: int = None, ) -> EvalReturn: """ Overview: diff --git a/lzero/agent/sampled_efficientzero.py b/lzero/agent/sampled_efficientzero.py index 079bdd11d..a60dae859 100644 --- a/lzero/agent/sampled_efficientzero.py +++ b/lzero/agent/sampled_efficientzero.py @@ -93,7 +93,12 @@ def __init__( cfg.main_config.exp_name = exp_name self.origin_cfg = cfg self.cfg = compile_config( - cfg.main_config, seed=seed, env=None, auto=True, policy=SampledEfficientZeroPolicy, create_cfg=cfg.create_config + cfg.main_config, + seed=seed, + env=None, + auto=True, + policy=SampledEfficientZeroPolicy, + create_cfg=cfg.create_config ) self.exp_name = self.cfg.exp_name @@ -110,6 +115,9 @@ def __init__( elif self.cfg.policy.model.model_type == 'conv': from lzero.model.sampled_efficientzero_model import SampledEfficientZeroModel model = SampledEfficientZeroModel(**self.cfg.policy.model) + elif self.cfg.policy.model.model_type == 'mlp_md': + from lzero.model.sampled_efficientzero_model_md import SampledEfficientZeroModelMD + model = SampledEfficientZeroModelMD(**self.cfg.policy.model) else: raise NotImplementedError if self.cfg.policy.cuda and torch.cuda.is_available(): @@ -124,8 +132,8 @@ def __init__( self.env_fn, self.collector_env_cfg, 
self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) def train( - self, - step: int = int(1e7), + self, + step: int = int(1e7), ) -> TrainingReturn: """ Overview: @@ -356,8 +364,8 @@ def deploy( return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) def batch_evaluate( - self, - n_evaluator_episode: int = None, + self, + n_evaluator_episode: int = None, ) -> EvalReturn: """ Overview: diff --git a/lzero/mcts/utils.py b/lzero/mcts/utils.py index c40052e62..11afad53a 100644 --- a/lzero/mcts/utils.py +++ b/lzero/mcts/utils.py @@ -6,8 +6,9 @@ from graphviz import Digraph -def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int, - reshape=False): +def generate_random_actions_discrete( + num_actions: int, action_space_size: int, num_of_sampled_actions: int, reshape=False +): """ Overview: Generate a list of random actions. @@ -19,10 +20,7 @@ def generate_random_actions_discrete(num_actions: int, action_space_size: int, n Returns: A list of random actions. """ - actions = [ - np.random.randint(0, action_space_size, num_of_sampled_actions).reshape(-1) - for _ in range(num_actions) - ] + actions = [np.random.randint(0, action_space_size, num_of_sampled_actions).reshape(-1) for _ in range(num_actions)] # If num_of_sampled_actions == 1, flatten the actions to a list of numbers if num_of_sampled_actions == 1: @@ -97,7 +95,9 @@ def prepare_observation(observation_list, model_type='conv'): Returns: - np.ndarray: Reshaped array of observations. """ - assert model_type in ['conv', 'mlp'], "model_type must be either 'conv' or 'mlp'" + assert model_type in [ + 'conv', 'mlp', 'rgcn', 'mlp_md' + ], "model_type must be either 'conv', 'mlp', 'rgcn' or 'mlp_md'" observation_array = np.array(observation_list) batch_size = observation_array.shape[0] @@ -110,13 +110,27 @@ def prepare_observation(observation_list, model_type='conv'): _, stack_num, channels, width, height = observation_array.shape observation_array = observation_array.reshape(batch_size, stack_num * channels, width, height) - elif model_type == 'mlp': + elif model_type == 'mlp' or model_type == 'mlp_md': if observation_array.ndim == 3: # Flatten the last two dimensions observation_array = observation_array.reshape(batch_size, -1) else: raise ValueError("For 'mlp' model_type, the observation must have 3 dimensions [B, S, O]") + elif model_type == 'rgcn': + if observation_array.ndim == 4: + # TODO(rjy): strage process + # observation_array should be reshaped to [B, S*M, O], where M is the agent number + # now observation_array.shape = [B, S, M, O] + observation_array = observation_array.reshape(batch_size, -1, observation_array.shape[-1]) + elif observation_array.ndim == 3: + # Flatten the last two dimensions + observation_array = observation_array.reshape(batch_size, -1) + else: + raise ValueError( + "For 'rgcn' model_type, the observation must have 3 dimensions [B, S, O] or 4 dimensions [B, S, M, O]" + ) + return observation_array diff --git a/lzero/model/common.py b/lzero/model/common.py index 363f7f779..2b798a163 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -8,6 +8,8 @@ import math from typing import Optional, Tuple from dataclasses import dataclass +import logging +import itertools import numpy as np import torch import torch.nn as nn @@ -36,10 +38,14 @@ class MZNetworkOutput: class DownSample(nn.Module): - - def __init__(self, observation_shape: SequenceType, out_channels: int, activation: nn.Module = nn.ReLU(inplace=True), - norm_type: 
Optional[str] = 'BN', - ) -> None: + + def __init__( + self, + observation_shape: SequenceType, + out_channels: int, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ) -> None: """ Overview: Define downSample convolution network. Encode the observation into hidden state. @@ -72,11 +78,7 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, activatio self.resblocks1 = nn.ModuleList( [ ResBlock( - in_channels=out_channels // 2, - activation=activation, - norm_type='BN', - res_type='basic', - bias=False + in_channels=out_channels // 2, activation=activation, norm_type='BN', res_type='basic', bias=False ) for _ in range(1) ] ) @@ -90,17 +92,15 @@ def __init__(self, observation_shape: SequenceType, out_channels: int, activatio ) self.resblocks2 = nn.ModuleList( [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type='BN', res_type='basic', bias=False - ) for _ in range(1) + ResBlock(in_channels=out_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(1) ] ) self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) self.resblocks3 = nn.ModuleList( [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type='BN', res_type='basic', bias=False - ) for _ in range(1) + ResBlock(in_channels=out_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(1) ] ) self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) @@ -174,15 +174,18 @@ def __init__( self.norm = nn.BatchNorm2d(num_channels) elif norm_type == 'LN': if downsample: - self.norm = nn.LayerNorm([num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + self.norm = nn.LayerNorm( + [num_channels, + math.ceil(observation_shape[-2] / 16), + math.ceil(observation_shape[-1] / 16)] + ) else: self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]]) - + self.resblocks = nn.ModuleList( [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False - ) for _ in range(num_res_blocks) + ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(num_res_blocks) ] ) self.activation = activation @@ -223,13 +226,13 @@ def get_param_mean(self) -> float: class RepresentationNetworkMLP(nn.Module): def __init__( - self, - observation_shape: int, - hidden_channels: int = 64, - layer_num: int = 2, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - last_linear_layer_init_zero: bool = True, - norm_type: Optional[str] = 'BN', + self, + observation_shape: int, + hidden_channels: int = 64, + layer_num: int = 2, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + norm_type: Optional[str] = 'BN', ) -> torch.Tensor: """ Overview: @@ -323,26 +326,35 @@ def __init__( self.resblocks = nn.ModuleList( [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False - ) for _ in range(num_res_blocks) + ResBlock(in_channels=num_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(num_res_blocks) ] ) self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - + if norm_type == 'BN': self.norm_value = nn.BatchNorm2d(value_head_channels) self.norm_policy = nn.BatchNorm2d(policy_head_channels) elif 
norm_type == 'LN': if downsample: - self.norm_value = nn.LayerNorm([value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) - self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)]) + self.norm_value = nn.LayerNorm( + [value_head_channels, + math.ceil(observation_shape[-2] / 16), + math.ceil(observation_shape[-1] / 16)] + ) + self.norm_policy = nn.LayerNorm( + [ + policy_head_channels, + math.ceil(observation_shape[-2] / 16), + math.ceil(observation_shape[-1] / 16) + ] + ) else: self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]]) self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]]) - + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head self.activation = activation @@ -404,16 +416,16 @@ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso class PredictionNetworkMLP(nn.Module): def __init__( - self, - action_space_size, - num_channels, - common_layer_num: int = 2, - fc_value_layers: SequenceType = [32], - fc_policy_layers: SequenceType = [32], - output_support_size: int = 601, - last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', + self, + action_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', ): """ Overview: diff --git a/lzero/model/common_gcn.py b/lzero/model/common_gcn.py new file mode 100644 index 000000000..73bd7a20c --- /dev/null +++ b/lzero/model/common_gcn.py @@ -0,0 +1,266 @@ +from typing import Optional, Tuple, Dict +import logging +import itertools + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +class RGCNLayer(nn.Module): + """ + Overview: + Relational graph convolutional network layer. 
+ """ + + def __init__( + self, + robot_num: int, + human_num: int, + robot_state_dim, + human_state_dim, + similarity_function, + num_layer=2, + X_dim=32, + layerwise_graph=False, + skip_connection=True, + wr_dims=[64, 32], # the last dim should equal to X_dim + wh_dims=[64, 32], # the last dim should equal to X_dim + final_state_dim=32, # should equal to X_dim + norm_type=None, + last_linear_layer_init_zero=True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + ): + super().__init__() + + # design choice + # 'gaussian', 'embedded_gaussian', 'cosine', 'cosine_softmax', 'concatenation' + self.similarity_function = similarity_function + self.robot_num = robot_num + self.human_num = human_num + self.robot_state_dim = robot_state_dim + self.human_state_dim = human_state_dim + self.num_layer = num_layer + self.X_dim = X_dim + self.layerwise_graph = layerwise_graph + self.skip_connection = skip_connection + + logging.info('Similarity_func: {}'.format(self.similarity_function)) + logging.info('Layerwise_graph: {}'.format(self.layerwise_graph)) + logging.info('Skip_connection: {}'.format(self.skip_connection)) + logging.info('Number of layers: {}'.format(self.num_layer)) + + self.w_r = MLP( + in_channels=robot_state_dim, + hidden_channels=wr_dims[0], + out_channels=wr_dims[1], + layer_num=num_layer, + activation=activation, + norm_type=norm_type, + last_linear_layer_init_zero=last_linear_layer_init_zero, + ) # inputs,64,32 + self.w_h = MLP( + in_channels=human_state_dim, + hidden_channels=wh_dims[0], + out_channels=wh_dims[1], + layer_num=num_layer, + activation=activation, + norm_type=norm_type, + last_linear_layer_init_zero=last_linear_layer_init_zero, + ) # inputs,64,32 + + if self.similarity_function == 'embedded_gaussian': + self.w_a = nn.Parameter(torch.randn(self.X_dim, self.X_dim)) + elif self.similarity_function == 'concatenation': + # TODO: fix the dim size + self.w_a = MLP( + in_channels=2 * X_dim, + hidden_channels=2 * X_dim, + out_channels=1, + layer_num=1, + ) + + embedding_dim = self.X_dim + self.Ws = torch.nn.ParameterList() + for i in range(self.num_layer): + if i == 0: + self.Ws.append(nn.Parameter(torch.randn(self.X_dim, embedding_dim))) + elif i == self.num_layer - 1: + self.Ws.append(nn.Parameter(torch.randn(embedding_dim, final_state_dim))) + else: + self.Ws.append(nn.Parameter(torch.randn(embedding_dim, embedding_dim))) + + # TODO: for visualization + self.A = None + + def compute_similarity_matrix(self, X): + if self.similarity_function == 'embedded_gaussian': + A = torch.matmul(torch.matmul(X, self.w_a), X.permute(0, 2, 1)) + normalized_A = nn.functional.softmax(A, dim=2) + elif self.similarity_function == 'gaussian': + A = torch.matmul(X, X.permute(0, 2, 1)) + normalized_A = nn.functional.softmax(A, dim=2) + elif self.similarity_function == 'cosine': + A = torch.matmul(X, X.permute(0, 2, 1)) + magnitudes = torch.norm(A, dim=2, keepdim=True) + norm_matrix = torch.matmul(magnitudes, magnitudes.permute(0, 2, 1)) + normalized_A = torch.div(A, norm_matrix) + elif self.similarity_function == 'cosine_softmax': + A = torch.matmul(X, X.permute(0, 2, 1)) + magnitudes = torch.norm(A, dim=2, keepdim=True) + norm_matrix = torch.matmul(magnitudes, magnitudes.permute(0, 2, 1)) + normalized_A = nn.functional.softmax(torch.div(A, norm_matrix), dim=2) + elif self.similarity_function == 'concatenation': + indices = [pair for pair in itertools.product(list(range(X.size(1))), repeat=2)] + selected_features = torch.index_select(X, dim=1, 
index=torch.LongTensor(indices).reshape(-1)) + pairwise_features = selected_features.reshape((-1, X.size(1) * X.size(1), X.size(2) * 2)) + A = self.w_a(pairwise_features).reshape(-1, X.size(1), X.size(1)) + normalized_A = A + elif self.similarity_function == 'squared': + A = torch.matmul(X, X.permute(0, 2, 1)) + squared_A = A * A + normalized_A = squared_A / torch.sum(squared_A, dim=2, keepdim=True) + elif self.similarity_function == 'equal_attention': + normalized_A = (torch.ones(X.size(1), X.size(1)) / X.size(1)).expand(X.size(0), X.size(1), X.size(1)) + elif self.similarity_function == 'diagonal': + normalized_A = (torch.eye(X.size(1), X.size(1))).expand(X.size(0), X.size(1), X.size(1)) + else: + raise NotImplementedError + + return normalized_A + + def forward(self, state): + state = state.to(self.w_r[0].weight.dtype) + if isinstance(state, dict): + robot_states = state['robot_state'] + human_states = state['human_state'] + elif isinstance(state, torch.Tensor): + if state.dim() == 3: + # state shape:(B, stack_num*(robot_num+human_num), state_dim) + stack_num = state.size(1) // (self.robot_num + self.human_num) + # robot_states shape:(B, stack_num*robot_num, state_dim) + robot_states = state[:, :stack_num * self.robot_num, :] + # human_states shape:(B, stack_num*human_num, state_dim) + human_states = state[:, stack_num * self.robot_num:, :] + elif state.dim() == 2: + # state shape:(B, stack_num*(robot_num+human_num)*state_dim) + stack_num = state.size(1) // ((self.robot_num + self.human_num) * self.robot_state_dim) + assert stack_num == 1, "stack_num should be 1 for 1-dim-array obs" + # robot_states shape:(B, stack_num*robot_num, state_dim) + robot_states = state[:, :stack_num * self.robot_num * + self.robot_state_dim].reshape(-1, self.robot_num, self.robot_state_dim) + # human_states shape:(B, stack_num*human_num, state_dim) + human_states = state[:, stack_num * self.robot_num * + self.robot_state_dim:].reshape(-1, self.human_num, self.human_state_dim) + + # compute feature matrix X + robot_state_embedings = self.w_r(robot_states) # batch x num x embedding_dim + human_state_embedings = self.w_h(human_states) + X = torch.cat([robot_state_embedings, human_state_embedings], dim=1) + + # compute matrix A + if not self.layerwise_graph: + normalized_A = self.compute_similarity_matrix(X) + self.A = normalized_A[0, :, :].data.cpu().numpy() # total_num x total_num + + # next_H = H = X + + H = X.contiguous().clone() + next_H = H.contiguous().clone() # batch x total_num x embedding_dim + for i in range(self.num_layer): # 2 + if self.layerwise_graph: # False + A = self.compute_similarity_matrix(H) + next_H = nn.functional.relu(torch.matmul(torch.matmul(A, H), self.Ws[i])) + else: # (A x H) x W_i + next_H = nn.functional.relu(torch.matmul(torch.matmul(normalized_A, H), self.Ws[i])) + + if self.skip_connection: + # next_H += H + next_H = next_H + H + H = next_H.contiguous().clone() + + return next_H + + +class RepresentationNetworkGCN(nn.Module): + + def __init__( + self, + robot_state_dim: int, + human_state_dim: int, + robot_num: int, + human_num: int, + hidden_channels: int = 64, + layer_num: int = 2, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + norm_type: Optional[str] = 'BN', + ) -> torch.Tensor: + """ + Overview: + Representation network used in MuZero and derived algorithms. + Arguments: + - robot_state_dim (:obj:`int`): The dimension of robot state. + - human_state_dim (:obj:`int`): The dimension of human state. 
+            - robot_num (:obj:`int`): The number of robots.
+            - human_num (:obj:`int`): The number of humans.
+            - hidden_channels (:obj:`int`): The channel of output hidden state.
+            - layer_num (:obj:`int`): The number of layers in the fc_representation MLP head.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \
+                Use the inplace operation to speed up.
+            - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer with zeros, \
+                which can provide stable zero outputs in the beginning, defaults to True.
+            - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+        """
+        super().__init__()
+        self.robot_state_dim = robot_state_dim
+        self.human_state_dim = human_state_dim
+        self.hidden_channels = hidden_channels
+        self.similarity_function = 'embedded_gaussian'
+        self.robot_num = robot_num
+        self.human_num = human_num
+        self.rgcn = RGCNLayer(
+            robot_num=self.robot_num,
+            human_num=self.human_num,
+            robot_state_dim=self.robot_state_dim,
+            human_state_dim=self.human_state_dim,
+            similarity_function=self.similarity_function,
+            num_layer=2,
+            X_dim=hidden_channels,
+            final_state_dim=hidden_channels,
+            wr_dims=[hidden_channels, hidden_channels],  # TODO: check dim
+            wh_dims=[hidden_channels, hidden_channels],
+            layerwise_graph=False,
+            skip_connection=True,
+            norm_type=None,
+        )
+        mlp_input_shape = (robot_num + human_num) * hidden_channels
+        self.fc_representation = MLP(
+            in_channels=mlp_input_shape,
+            hidden_channels=hidden_channels,
+            out_channels=hidden_channels,
+            layer_num=layer_num,
+            activation=activation,
+            norm_type=norm_type,
+            # not using activation and norm in the last layer of the representation network is important for convergence.
+            output_activation=False,
+            output_norm=False,
+            # last_linear_layer_init_zero=True is beneficial for convergence speed.
+            last_linear_layer_init_zero=True,
+        )
+
+    def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        Shapes:
+            - x (:obj:`torch.Tensor`): :math:`(B, M, O)`, where B is batch size, M is the number of agent nodes \
+                (stack_num * (robot_num + human_num)) and O is the per-agent state dimension.
+            - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
+ """ + gcn_embedding = self.rgcn(x) + gcn_embedding = gcn_embedding.view(gcn_embedding.shape[0], -1) # (B,M,N) -> (B,M*N) + return self.fc_representation(gcn_embedding) diff --git a/lzero/model/efficientzero_model_md.py b/lzero/model/efficientzero_model_md.py new file mode 100644 index 000000000..d520995d8 --- /dev/null +++ b/lzero/model/efficientzero_model_md.py @@ -0,0 +1,479 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray + +from .common import EZNetworkOutput, RepresentationNetworkMLP +from .muzero_model_md import PredictionNetworkMD +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +@MODEL_REGISTRY.register('EfficientZeroModelMD') +class EfficientZeroModelMD(nn.Module): + + def __init__( + self, + agent_num: int, + output_separate_logit: bool = False, + observation_shape: int = 2, + single_agent_action_size: int = 5, + action_space_size: int = 6, + lstm_hidden_size: int = 512, + latent_state_dim: int = 256, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = True, + categorical_distribution: bool = True, + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + res_connection_in_dynamics: bool = False, + *args, + **kwargs, + ): + """ + Overview: + The definition of the network model of EfficientZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. + - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. 
+ - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(EfficientZeroModelMD, self).__init__() + if not categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. + self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.lstm_hidden_size = lstm_hidden_size + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type + ) + + self.dynamics_network = DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=lstm_hidden_size, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMD( + agent_num=agent_num, + single_agent_action_size=single_agent_action_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + output_separate_logit=output_separate_logit, + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss 
related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: + """ + Overview: + Initial inference of EfficientZero model, which is the first step of the EfficientZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward_hidden_state`` for the next step of the EfficientZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + batch_size = obs.size(0) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + # zero initialization for reward hidden states + # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) + reward_hidden_state = ( + torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs.device) + ) + return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) + + def recurrent_inference( + self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor + ) -> EZNetworkOutput: + """ + Overview: + Recurrent inference of EfficientZero model, which is the rollout step of the EfficientZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. 
+ Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action) + policy_logits, value = self._prediction(next_latent_state) + return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + observation = observation.float() + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. 
+ """ + policy_logits, value = self.prediction_network(latent_state) + return policy_logits, value + + def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, + action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``value_prefix`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + # NOTE: the key difference with MuZero + next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( + state_action_encoding, reward_hidden_state + ) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + return next_latent_state, next_reward_hidden_state, value_prefix + + def project(self, latent_state: torch.Tensor, with_grad=True): + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. 
+ - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetworkMLP(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + lstm_hidden_size: int = 512, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in EfficientZero algorithm, which is used to predict next latent state + value_prefix and reward_hidden_state by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - lstm_hidden_size (:obj:`int`): The hidden size of lstm in dynamics network. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializationss for the last layer of value/policy head, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
+ """ + super().__init__() + assert num_channels > action_encoding_dim, f'num_channels:{num_channels} <= action_encoding_dim:{action_encoding_dim}' + + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + self.lstm_hidden_size = lstm_hidden_size + self.activation = activation + self.res_connection_in_dynamics = res_connection_in_dynamics + + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # input_shape: (sequence_length,batch_size,input_size) + # output_shape: (sequence_length, batch_size, hidden_size) + self.lstm = nn.LSTM(input_size=self.latent_state_dim, hidden_size=self.lstm_hidden_size) + + self.fc_reward_head = MLP( + in_channels=self.lstm_hidden_size, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor, reward_hidden_state): + """ + Overview: + Forward computation of the dynamics network. Predict next latent state given current state_action_encoding and reward hidden state. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + - reward_hidden_state (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The input hidden state of LSTM about reward. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - next_reward_hidden_state (:obj:`torch.Tensor`): The input hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. 
+ """ + if self.res_connection_in_dynamics: + # take the state encoding (latent_state), state_action_encoding[:, -self.action_encoding_dim] + # is action encoding + latent_state = state_action_encoding[:, :-self.action_encoding_dim] + x = self.fc_dynamics_1(state_action_encoding) + # the residual link: add state encoding to the state_action encoding + next_latent_state = x + latent_state + next_latent_state_ = self.fc_dynamics_2(next_latent_state) + else: + next_latent_state = self.fc_dynamics(state_action_encoding) + next_latent_state_ = next_latent_state + + next_latent_state_unsqueeze = next_latent_state_.unsqueeze(0) + value_prefix, next_reward_hidden_state = self.lstm(next_latent_state_unsqueeze, reward_hidden_state) + value_prefix = self.fc_reward_head(value_prefix.squeeze(0)) + + return next_latent_state, next_reward_hidden_state, value_prefix + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> Tuple[ndarray, float]: + return get_reward_mean(self) diff --git a/lzero/model/muzero_model_gcn.py b/lzero/model/muzero_model_gcn.py new file mode 100644 index 000000000..08770a880 --- /dev/null +++ b/lzero/model/muzero_model_gcn.py @@ -0,0 +1,456 @@ +from typing import Optional, Tuple, Dict + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP +from .common_gcn import RepresentationNetworkGCN +from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean + + +@MODEL_REGISTRY.register('MuZeroModelGCN') +class MuZeroModelGCN(nn.Module): + + def __init__( + self, + robot_state_dim: int, + human_state_dim: int, + robot_num: int, + human_num: int, + action_space_size: int, + latent_state_dim: int = 64, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + discrete_action_encoding_type: str = 'one_hot', + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + *args, + **kwargs + ): + """ + Overview: + The definition of the network model of MuZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP network which predicts the next latent state, and reward given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, e.g. 4 for Lunarlander. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). 
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in MuZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(MuZeroModelGCN, self).__init__() + self.categorical_distribution = categorical_distribution + if not self.categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. 
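Reviewer note on the support sizes above: when categorical_distribution is disabled the value/reward heads regress a single scalar, which is why both support sizes collapse to 1; otherwise the heads emit logits over a discrete support, and a size of 601 corresponds to the integer atoms in [-300, 300]. The snippet below is a hypothetical illustration of the standard two-hot support encoding from the MuZero paper, not the transform utility this repository actually uses:

# Illustrative only (reviewer sketch, not part of the patch): encode a scalar target
# on a discrete support of 2 * 300 + 1 = 601 atoms.
import torch

def scalar_to_support(x: torch.Tensor, support_size: int = 300) -> torch.Tensor:
    # x: (B,) scalar targets, assumed to be already reward/value-transformed upstream.
    x = x.clamp(-support_size, support_size)
    floor = x.floor()
    prob_upper = x - floor  # weight assigned to the upper neighbouring atom
    target = torch.zeros(x.shape[0], 2 * support_size + 1)
    idx = (floor + support_size).long().unsqueeze(-1)
    target.scatter_(1, idx, (1 - prob_upper).unsqueeze(-1))
    target.scatter_add_(1, (idx + 1).clamp(max=2 * support_size), prob_upper.unsqueeze(-1))
    return target  # (B, 601); the expectation over the support recovers x

# e.g. scalar_to_support(torch.tensor([2.3]))[0, 300 + 2] is approximately 0.7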
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.latent_state_dim = latent_state_dim + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + self.representation_network = RepresentationNetworkGCN( + robot_state_dim=robot_state_dim, + human_state_dim=human_state_dim, + robot_num=robot_num, + human_num=human_num, + hidden_channels=self.latent_state_dim, + layer_num=2, + norm_type=norm_type + ) + + self.dynamics_network = DynamicsNetwork( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + action_space_size=action_space_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Initial inference of MuZero model, which is the first step of the MuZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward`` for the next step of the MuZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. 
+ - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + batch_size = obs.size(0) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of MuZero model, which is the rollout step of the MuZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input obs. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + next_latent_state, reward = self._dynamics(latent_state, action) + policy_logits, value = self._prediction(next_latent_state) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. 
+ """ + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + policy_logits, value = self.prediction_network(latent_state) + return policy_logits, value + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``reward`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
+ state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if not self.state_norm: + return next_latent_state, reward + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, reward + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state + reward by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
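+        Examples:
+            >>> # A minimal shape-check sketch under hypothetical sizes: latent_state_dim=256 and a
+            >>> # one-hot encoded action of size 4, so num_channels = 256 + 4 = 260.
+            >>> net = DynamicsNetwork(action_encoding_dim=4, num_channels=260, fc_reward_layers=[32], output_support_size=601)
+            >>> state_action_encoding = torch.randn(8, 260)
+            >>> next_latent_state, reward = net(state_action_encoding)
+            >>> next_latent_state.shape, reward.shape  # (8, 256), (8, 601)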
+ """ + super().__init__() + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.res_connection_in_dynamics = res_connection_in_dynamics + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_reward_head = MLP( + in_channels=self.latent_state_dim, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + """ + if self.res_connection_in_dynamics: + # take the state encoding (e.g. 
latent_state),
+            # state_action_encoding[:, -self.action_encoding_dim:] is action encoding
+            latent_state = state_action_encoding[:, :-self.action_encoding_dim]
+            x = self.fc_dynamics_1(state_action_encoding)
+            # the residual link: add the latent_state to the state_action encoding
+            next_latent_state = x + latent_state
+            next_latent_state_encoding = self.fc_dynamics_2(next_latent_state)
+        else:
+            next_latent_state = self.fc_dynamics(state_action_encoding)
+            next_latent_state_encoding = next_latent_state
+
+        reward = self.fc_reward_head(next_latent_state_encoding)
+
+        return next_latent_state, reward
+
+    def get_dynamic_mean(self) -> float:
+        return get_dynamic_mean(self)
+
+    def get_reward_mean(self) -> float:
+        return get_reward_mean(self)
diff --git a/lzero/model/muzero_model_md.py b/lzero/model/muzero_model_md.py
new file mode 100644
index 000000000..cca5f863e
--- /dev/null
+++ b/lzero/model/muzero_model_md.py
@@ -0,0 +1,557 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from ding.torch_utils import MLP
+from ding.utils import MODEL_REGISTRY, SequenceType
+from ding.model.common.head import MultiHead, DiscreteHead
+
+from .common import MZNetworkOutput, RepresentationNetworkMLP
+from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean
+
+
+@MODEL_REGISTRY.register('MuZeroModelMD')
+class MuZeroModelMD(nn.Module):
+
+    def __init__(
+        self,
+        agent_num: int,
+        output_separate_logit: bool = False,
+        observation_shape: int = 2,
+        single_agent_action_size: int = 5,
+        action_space_size: int = 6,
+        latent_state_dim: int = 256,
+        fc_reward_layers: SequenceType = [32],
+        fc_value_layers: SequenceType = [32],
+        reward_support_size: int = 601,
+        value_support_size: int = 601,
+        proj_hid: int = 1024,
+        proj_out: int = 1024,
+        pred_hid: int = 512,
+        pred_out: int = 1024,
+        self_supervised_learning_loss: bool = False,
+        categorical_distribution: bool = True,
+        activation: Optional[nn.Module] = nn.ReLU(inplace=True),
+        last_linear_layer_init_zero: bool = True,
+        state_norm: bool = False,
+        discrete_action_encoding_type: str = 'one_hot',
+        norm_type: Optional[str] = 'BN',
+        res_connection_in_dynamics: bool = False,
+        *args,
+        **kwargs
+    ):
+        """
+        Overview:
+            The definition of the network model of MuZero, which is a generalization version for 1D vector obs.
+            The networks are mainly built on fully connected layers.
+            The representation network is an MLP network which maps the raw observation to a latent state.
+            The dynamics network is an MLP network which predicts the next latent state and reward given the current latent state and action.
+            The prediction network is a multi-head network with one policy head per agent (agent_num heads in total), which predicts the value and the per-agent policies given the current latent state.
+        Arguments:
+            - agent_num (:obj:`int`): The number of agents in the environment.
+            - output_separate_logit (:obj:`bool`): Whether to output a separate policy logit for each agent instead of the joint-action logits.
+            - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander.
+            - action_space_size: (:obj:`int`): The size of the joint (combinatorial) action space of all agents.
+            - single_agent_action_size: (:obj:`int`): The size of action space for a single agent.
+            - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256.
+            - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head).
+            - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in MuZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + - discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(MuZeroModelMD, self).__init__() + self.categorical_distribution = categorical_distribution + if not self.categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.action_space_size = action_space_size + self.continuous_action_space = False + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. 
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.latent_state_dim = latent_state_dim + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + self.self_supervised_learning_loss = self_supervised_learning_loss + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.res_connection_in_dynamics = res_connection_in_dynamics + + self.representation_network = RepresentationNetworkMLP( + observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type + ) + + self.dynamics_network = DynamicsNetwork( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMD( + agent_num=agent_num, + single_agent_action_size=single_agent_action_size, + num_channels=latent_state_dim, + fc_value_layers=fc_value_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + output_separate_logit=output_separate_logit, + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Initial inference of MuZero model, which is the first step of the MuZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward`` for the next step of the MuZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. 
+ - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + batch_size = obs.size(0) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of MuZero model, which is the rollout step of the MuZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input obs. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + next_latent_state, reward = self._dynamics(latent_state, action) + policy_logits, value = self._prediction(next_latent_state) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. 
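+        Examples:
+            >>> # A minimal shape-check sketch under hypothetical sizes (observation_shape=16, latent_state_dim=256):
+            >>> obs = torch.randn(8, 16)
+            >>> latent_state = self._representation(obs)
+            >>> latent_state.shape  # (8, 256)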
+ """ + observation = observation.float() + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + """ + policy_logits, value = self.prediction_network(latent_state) + return policy_logits, value + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``reward`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
+ state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if not self.state_norm: + return next_latent_state, reward + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, reward + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self) -> float: + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + + def __init__( + self, + action_encoding_dim: int = 2, + num_channels: int = 64, + common_layer_num: int = 2, + fc_reward_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + res_connection_in_dynamics: bool = False, + ): + """ + Overview: + The definition of dynamics network in MuZero algorithm, which is used to predict next latent state + reward by the given current latent state and action. + The networks are mainly built on fully connected layers. + Arguments: + - action_encoding_dim (:obj:`int`): The dimension of action encoding. + - num_channels (:obj:`int`): The num of channels in latent states. + - common_layer_num (:obj:`int`): The number of common layers in dynamics network. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - output_support_size (:obj:`int`): The size of categorical reward output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection in dynamics network. 
+ """ + super().__init__() + self.num_channels = num_channels + self.action_encoding_dim = action_encoding_dim + self.latent_state_dim = self.num_channels - self.action_encoding_dim + + self.res_connection_in_dynamics = res_connection_in_dynamics + if self.res_connection_in_dynamics: + self.fc_dynamics_1 = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + self.fc_dynamics_2 = MLP( + in_channels=self.latent_state_dim, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + else: + self.fc_dynamics = MLP( + in_channels=self.num_channels, + hidden_channels=self.latent_state_dim, + layer_num=common_layer_num, + out_channels=self.latent_state_dim, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + self.fc_reward_head = MLP( + in_channels=self.latent_state_dim, + hidden_channels=fc_reward_layers[0], + layer_num=2, + out_channels=output_support_size, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the dynamics network. Predict the next latent state given current latent state and action. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The state-action encoding, which is the concatenation of \ + latent state and action encoding, with shape (batch_size, num_channels, height, width). + Returns: + - next_latent_state (:obj:`torch.Tensor`): The next latent state, with shape (batch_size, latent_state_dim). + - reward (:obj:`torch.Tensor`): The predicted reward for input state. + """ + if self.res_connection_in_dynamics: + # take the state encoding (e.g. 
latent_state), + # state_action_encoding[:, -self.action_encoding_dim:] is action encoding + latent_state = state_action_encoding[:, :-self.action_encoding_dim] + x = self.fc_dynamics_1(state_action_encoding) + # the residual link: add the latent_state to the state_action encoding + next_latent_state = x + latent_state + next_latent_state_encoding = self.fc_dynamics_2(next_latent_state) + else: + next_latent_state = self.fc_dynamics(state_action_encoding) + next_latent_state_encoding = next_latent_state + + reward = self.fc_reward_head(next_latent_state_encoding) + + return next_latent_state, reward + + def get_dynamic_mean(self) -> float: + return get_dynamic_mean(self) + + def get_reward_mean(self) -> float: + return get_reward_mean(self) + + +class PredictionNetworkMD(nn.Module): + + def __init__( + self, + agent_num: int, + single_agent_action_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + output_separate_logit: bool = False, + ): + """ + Overview: + The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), + which is used to predict value and policy by the given latent state. Policy network is a multihead network, + which predicts the policy for each agent. + Arguments: + - agent_num (:obj:`int`): The number of agents in the environment. + - single_agent_action_size: (:obj:`int`): Action space size for single agent. + - num_channels (:obj:`int`): The channels of latent states. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.policy_multi_head = MultiHead( + head_cls=DiscreteHead, + hidden_size=self.num_channels, + output_size_list=[single_agent_action_size for _ in range(agent_num)], + ) + self.output_separate_logit = output_separate_logit + + def forward(self, latent_state: torch.Tensor): + """ + Overview: + Forward computation of the prediction network. 
+ Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + """ + latent_state = latent_state.to(torch.float32) + x_prediction_common = self.fc_prediction_common(latent_state) + + value = self.fc_value_head(x_prediction_common) + # policy_list: {'logit': [policy1, policy2, ...],} + # policyi shape: (B, action_space_size) + policy_list = self.policy_multi_head(x_prediction_common)['logit'] + if not self.output_separate_logit: + # The joint action space policy is the product of each agent policy + # policy shape: (B, action_space_size^^agent_num) + batch_size = latent_state.size(0) + joint_logits_batches = [] + for i in range(batch_size): + current_batch = [policy[i] for policy in policy_list] + cartesian_prod_result = torch.cartesian_prod(*current_batch) + joint_logits = cartesian_prod_result.prod(dim=1) + joint_logits_batches.append(joint_logits) + policy = torch.stack(joint_logits_batches) + else: + # policy_list: [policy1, policy2, ...] + # policy sahpe: (B, agent_num, action_space_size) + policy = torch.stack(policy_list, dim=1) + return policy, value diff --git a/lzero/model/sampled_efficientzero_model_gcn.py b/lzero/model/sampled_efficientzero_model_gcn.py new file mode 100644 index 000000000..70fe8bf05 --- /dev/null +++ b/lzero/model/sampled_efficientzero_model_gcn.py @@ -0,0 +1,534 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.model.common import ReparameterizationHead +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from .common import EZNetworkOutput +from .common_gcn import RepresentationNetworkGCN +from .efficientzero_model_mlp import DynamicsNetworkMLP +from .utils import renormalize, get_params_mean + + +@MODEL_REGISTRY.register('SampledEfficientZeroModelMLP') +class SampledEfficientZeroModelMLP(nn.Module): + + def __init__( + self, + robot_state_dim: int = 10, + human_state_dim: int = 10, + robot_num: int = 5, + human_num: int = 5, + action_space_size: int = 6, + latent_state_dim: int = 256, + lstm_hidden_size: int = 512, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = True, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + # ============================================================== + # specific sampled related config + # ============================================================== + continuous_action_space: bool = False, + num_of_sampled_actions: int = 6, + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = 'BN', + discrete_action_encoding_type: str = 'one_hot', + res_connection_in_dynamics: bool = False, + *args, + **kwargs, + ): + """ + Overview: + The definition of the network model of Sampled EfficientZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. 
+ Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, which is an integer number. For discrete action space, it is the num of discrete actions, \ + e.g. 4 for Lunarlander. For continuous action space, it is the dimension of the continuous action, e.g. 4 for bipedalwalker. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. + - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + # ============================================================== + # specific sampled related config + # ============================================================== + - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. + - num_of_sampled_actions (:obj:`int`): the number of sampled actions, i.e. the K in original Sampled MuZero paper. + # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments. + - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. + - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, + - bound_type (:obj:`str`): The type of bound in networks. Default sets it to None. + - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. 
options = {'one_hot', 'not_one_hot'} + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(SampledEfficientZeroModelMLP, self).__init__() + if not categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.continuous_action_space = continuous_action_space + # self.observation_shape = observation_shape + self.robot_state_dim = robot_state_dim + self.human_state_dim = human_state_dim + self.robot_num = robot_num + self.human_num = human_num + self.action_space_size = action_space_size + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. + self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.lstm_hidden_size = lstm_hidden_size + self.latent_state_dim = latent_state_dim + self.fc_reward_layers = fc_reward_layers + self.fc_value_layers = fc_value_layers + self.fc_policy_layers = fc_policy_layers + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.self_supervised_learning_loss = self_supervised_learning_loss + + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.norm_type = norm_type + self.num_of_sampled_actions = num_of_sampled_actions + self.res_connection_in_dynamics = res_connection_in_dynamics + + self.representation_network = RepresentationNetworkGCN( + robot_state_dim=robot_state_dim, + human_state_dim=human_state_dim, + robot_num=robot_num, + human_num=human_num, + hidden_channels=self.latent_state_dim, + norm_type=norm_type + ) + + self.dynamics_network = DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=self.lstm_hidden_size, + fc_reward_layers=self.fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMLP( + continuous_action_space=self.continuous_action_space, + action_space_size=self.action_space_size, + num_channels=self.latent_state_dim, + fc_value_layers=self.fc_value_layers, + fc_policy_layers=self.fc_policy_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + bound_type=self.bound_type, + norm_type=self.norm_type, + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + self.projection = 
nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: + """ + Overview: + Initial inference of SampledEfficientZero model, which is the first step of the SampledEfficientZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. + Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward_hidden_state`` for the next step of the Sampled EfficientZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + batch_size = obs.size(0) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + # zero initialization for reward hidden states + # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) + reward_hidden_state = ( + torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs.device) + ) + return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) + + def recurrent_inference( + self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor + ) -> EZNetworkOutput: + """ + Overview: + Recurrent inference of Sampled EfficientZero model, which is the rollout step of the Sampled EfficientZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. 
+ - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action) + policy_logits, value = self._prediction(next_latent_state) + return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state) + + def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + """ + latent_state = self.representation_network(observation) + if self.state_norm: + latent_state = renormalize(latent_state) + return latent_state + + def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Use the representation network to encode the observations into latent state. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns: + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. 
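+        Examples:
+            >>> # A minimal shape-check sketch under hypothetical sizes (latent_state_dim=256, a discrete
+            >>> # action space of size 6 and value_support_size=601):
+            >>> latent_state = torch.randn(8, 256)
+            >>> policy, value = self._prediction(latent_state)
+            >>> policy.shape, value.shape  # (8, 6), (8, 601)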
+ """ + policy, value = self.prediction_network(latent_state) + return policy, value + + def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple, + action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]: + """ + Overview: + Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state`` + ``value_prefix`` and ``next_reward_hidden_state``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns: + - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep. + - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + if not self.continuous_action_space: + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + else: + # continuous action space + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + elif len(action.shape) == 3: + # (batch_size, action_dim, 1) -> (batch_size, action_dim) + # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2]) + action = action.squeeze(-1) + + action_encoding = action + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. 
+ state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( + state_action_encoding, reward_hidden_state + ) + + if not self.state_norm: + return next_latent_state, next_reward_hidden_state, value_prefix + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, next_reward_hidden_state, value_prefix + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. + + Examples: + >>> latent_state = torch.randn(256, 64) + >>> output = self.project(latent_state) + >>> output.shape # (256, 1024) + """ + proj = self.projection(latent_state) + + if with_grad: + # with grad, use prediction_head + return self.prediction_head(proj) + else: + return proj.detach() + + def get_params_mean(self): + return get_params_mean(self) + + +class PredictionNetworkMLP(nn.Module): + + def __init__( + self, + continuous_action_space, + action_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + # ============================================================== + # specific sampled related config + # ============================================================== + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = 'BN', + ): + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + The networks are mainly built on fully connected layers. + Arguments: + - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. + - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ + space, it is the number of discrete actions. For continuous action space, it is the dimension of \ + continuous action. + - num_channels (:obj:`int`): The num of channels in latent states. + - num_res_blocks (:obj:`int`): The number of res blocks. + - fc_value_layers (:obj:`SequenceType`): hidden layers of the value prediction head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): hidden layers of the policy prediction head (MLP head). + - output_support_size (:obj:`int`): dim of value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. 
+ # ============================================================== + # specific sampled related config + # ============================================================== + # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about thee following arguments. + - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. + - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, + - bound_type (:obj:`str`): The type of bound in networks. default set it to None. + - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + self.continuous_action_space = continuous_action_space + self.norm_type = norm_type + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.action_space_size = action_space_size + if self.continuous_action_space: + self.action_encoding_dim = self.action_space_size + else: + self.action_encoding_dim = 1 + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=2, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + # sampled related core code + if self.continuous_action_space: + self.fc_policy_head = ReparameterizationHead( + input_size=self.num_channels, + output_size=action_space_size, + layer_num=2, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + activation=nn.ReLU(), + norm_type=None, + bound_type=self.bound_type + ) + else: + self.fc_policy_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=2, + activation=activation, + norm_type=self.norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, in_channels). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor. If action space is discrete, shape is (B, action_space_size). + If action space is continuous, shape is (B, action_space_size * 2). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + x_prediction_common = self.fc_prediction_common(latent_state) + value = self.fc_value_head(x_prediction_common) + + # sampled related core code + policy = self.fc_policy_head(x_prediction_common) + if self.continuous_action_space: + policy = torch.cat([policy['mu'], policy['sigma']], dim=-1) + + return policy, value diff --git a/lzero/model/sampled_efficientzero_model_md.py b/lzero/model/sampled_efficientzero_model_md.py new file mode 100644 index 000000000..ec2673a3a --- /dev/null +++ b/lzero/model/sampled_efficientzero_model_md.py @@ -0,0 +1,547 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from ding.model.common import ReparameterizationHead +from ding.model.common.head import MultiHead, DiscreteHead +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, SequenceType + +from .common import EZNetworkOutput, RepresentationNetworkMLP +from .efficientzero_model_mlp import DynamicsNetworkMLP +from .utils import renormalize, get_params_mean + + +@MODEL_REGISTRY.register('SampledEfficientZeroModelMD') +class SampledEfficientZeroModelMD(nn.Module): + + def __init__( + self, + agent_num: int, + single_agent_action_size: int, + output_separate_logit: bool = False, + observation_shape: int = 2, + action_space_size: int = 6, + latent_state_dim: int = 256, + lstm_hidden_size: int = 512, + fc_reward_layers: SequenceType = [32], + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = True, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + # ============================================================== + # specific sampled related config + # ============================================================== + continuous_action_space: bool = False, + num_of_sampled_actions: int = 6, + sigma_type='conditioned', + fixed_sigma_value: float = 0.3, + bound_type: str = None, + norm_type: str = 'BN', + discrete_action_encoding_type: str = 'one_hot', + res_connection_in_dynamics: bool = False, + *args, + **kwargs, + ): + """ + Overview: + The definition of the network model of Sampled EfficientZero, which is a generalization version for 1D vector obs. + The networks are mainly built on fully connected layers. + Sampled EfficientZero model consists of a representation network, a dynamics network and a prediction network. + The representation network is an MLP network which maps the raw observation to a latent state. + The dynamics network is an MLP+LSTM network which predicts the next latent state, reward_hidden_state and value_prefix given the current latent state and action. + The prediction network is an MLP network which predicts the value and policy given the current latent state. + Arguments: + - observation_shape (:obj:`int`): Observation space shape, e.g. 8 for Lunarlander. + - action_space_size: (:obj:`int`): Action space size, which is an integer number. For discrete action space, it is the num of discrete actions, \ + e.g. 4 for Lunarlander. For continuous action space, it is the dimension of the continuous action, e.g. 4 for bipedalwalker. + - latent_state_dim (:obj:`int`): The dimension of latent state, such as 256. 
+ - lstm_hidden_size (:obj:`int`): The hidden size of LSTM in dynamics network to predict value_prefix. + - fc_reward_layers (:obj:`SequenceType`): The number of hidden layers of the reward head (MLP head). + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - reward_support_size (:obj:`int`): The size of categorical reward output + - value_support_size (:obj:`int`): The size of categorical value output. + - proj_hid (:obj:`int`): The size of projection hidden layer. + - proj_out (:obj:`int`): The size of projection output layer. + - pred_hid (:obj:`int`): The size of prediction hidden layer. + - pred_out (:obj:`int`): The size of prediction output layer. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks in Sampled EfficientZero model, default set it to False. + - categorical_distribution (:obj:`bool`): Whether to use discrete support to represent categorical distribution for value, reward/value_prefix. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True. + - state_norm (:obj:`bool`): Whether to use normalization for latent states, default sets it to True. + # ============================================================== + # specific sampled related config + # ============================================================== + - continuous_action_space (:obj:`bool`): The type of action space. default set it to False. + - num_of_sampled_actions (:obj:`int`): the number of sampled actions, i.e. the K in original Sampled MuZero paper. + # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments. + - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}. + - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network, + - bound_type (:obj:`str`): The type of bound in networks. Default sets it to None. + - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'. + - discrete_action_encoding_type (:obj:`str`): The type of encoding for discrete action. Default sets it to 'one_hot'. options = {'one_hot', 'not_one_hot'} + - res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. + """ + super(SampledEfficientZeroModelMD, self).__init__() + if not categorical_distribution: + self.reward_support_size = 1 + self.value_support_size = 1 + else: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + + self.continuous_action_space = continuous_action_space + self.observation_shape = observation_shape + self.action_space_size = action_space_size + # The dim of action space. For discrete action space, it is 1. + # For continuous action space, it is the dimension of continuous action. 
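# Illustrative note (not part of the patch; values are assumed examples): with action_space_size=25 and
# discrete_action_encoding_type='one_hot', action_encoding_dim below becomes 25, so the dynamics network
# input has latent_state_dim + 25 channels; with 'not_one_hot' the action is encoded as a single
# normalized scalar and the input has latent_state_dim + 1 channels.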
+ self.action_space_dim = action_space_size if self.continuous_action_space else 1 + assert discrete_action_encoding_type in ['one_hot', 'not_one_hot'], discrete_action_encoding_type + self.discrete_action_encoding_type = discrete_action_encoding_type + if self.continuous_action_space: + self.action_encoding_dim = action_space_size + else: + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + + self.lstm_hidden_size = lstm_hidden_size + self.latent_state_dim = latent_state_dim + self.fc_reward_layers = fc_reward_layers + self.fc_value_layers = fc_value_layers + self.fc_policy_layers = fc_policy_layers + self.proj_hid = proj_hid + self.proj_out = proj_out + self.pred_hid = pred_hid + self.pred_out = pred_out + + self.last_linear_layer_init_zero = last_linear_layer_init_zero + self.state_norm = state_norm + self.self_supervised_learning_loss = self_supervised_learning_loss + + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.norm_type = norm_type + self.num_of_sampled_actions = num_of_sampled_actions + self.res_connection_in_dynamics = res_connection_in_dynamics + + self.representation_network = RepresentationNetworkMLP( + observation_shape=self.observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type + ) + + self.dynamics_network = DynamicsNetworkMLP( + action_encoding_dim=self.action_encoding_dim, + num_channels=self.latent_state_dim + self.action_encoding_dim, + common_layer_num=2, + lstm_hidden_size=self.lstm_hidden_size, + fc_reward_layers=self.fc_reward_layers, + output_support_size=self.reward_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + norm_type=norm_type, + res_connection_in_dynamics=self.res_connection_in_dynamics, + ) + + self.prediction_network = PredictionNetworkMD( + agent_num=agent_num, + single_agent_action_size=single_agent_action_size, + continuous_action_space=self.continuous_action_space, + action_space_size=self.action_space_size, + num_channels=self.latent_state_dim, + fc_value_layers=self.fc_value_layers, + output_support_size=self.value_support_size, + last_linear_layer_init_zero=self.last_linear_layer_init_zero, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + bound_type=self.bound_type, + norm_type=self.norm_type, + output_separate_logit=output_separate_logit, + ) + + if self.self_supervised_learning_loss: + # self_supervised_learning_loss related network proposed in EfficientZero + self.projection_input_dim = latent_state_dim + self.projection = nn.Sequential( + nn.Linear(self.projection_input_dim, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_hid), nn.BatchNorm1d(self.proj_hid), activation, + nn.Linear(self.proj_hid, self.proj_out), nn.BatchNorm1d(self.proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(self.proj_out, self.pred_hid), + nn.BatchNorm1d(self.pred_hid), + activation, + nn.Linear(self.pred_hid, self.pred_out), + ) + + def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput: + """ + Overview: + Initial inference of SampledEfficientZero model, which is the first step of the SampledEfficientZero model. + To perform the initial inference, we first use the representation network to obtain the "latent_state" of the observation. 
+ Then we use the prediction network to predict the "value" and "policy_logits" of the "latent_state", and + also prepare the zeros-like ``reward_hidden_state`` for the next step of the Sampled EfficientZero model. + Arguments: + - obs (:obj:`torch.Tensor`): The 1D vector observation data. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The hidden state of LSTM about reward. In initial inference, \ + we set it to the zeros-like hidden state (H and C). + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size. + """ + batch_size = obs.size(0) + obs = obs.to(torch.float32) + latent_state = self._representation(obs) + policy_logits, value = self._prediction(latent_state) + # zero initialization for reward hidden states + # (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size) + reward_hidden_state = ( + torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size, + self.lstm_hidden_size).to(obs.device) + ) + return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state) + + def recurrent_inference( + self, latent_state: torch.Tensor, reward_hidden_state: torch.Tensor, action: torch.Tensor + ) -> EZNetworkOutput: + """ + Overview: + Recurrent inference of Sampled EfficientZero model, which is the rollout step of the Sampled EfficientZero model. + To perform the recurrent inference, we first use the dynamics network to predict ``next_latent_state``, + ``reward_hidden_state``, ``value_prefix`` by the given current ``latent_state`` and ``action``. + We then use the prediction network to predict the ``value`` and ``policy_logits``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (EZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward. + Shapes: + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. 
+ - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+ - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size.
+ - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+ - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
+ - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
+ - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size.
+ """
+ next_latent_state, reward_hidden_state, value_prefix = self._dynamics(latent_state, reward_hidden_state, action)
+ policy_logits, value = self._prediction(next_latent_state)
+ return EZNetworkOutput(value, value_prefix, policy_logits, next_latent_state, reward_hidden_state)
+
+ def _representation(self, observation: torch.Tensor) -> Tuple[torch.Tensor]:
+ """
+ Overview:
+ Use the representation network to encode the observations into latent state.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The 1D vector observation data.
+ Returns:
+ - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, obs_shape)`, where B is batch_size.
+ - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
+ """
+ latent_state = self.representation_network(observation)
+ if self.state_norm:
+ latent_state = renormalize(latent_state)
+ return latent_state
+
+ def _prediction(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Use the prediction network to predict the ``value`` and ``policy_logits`` of the given ``latent_state``.
+ Arguments:
+ - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+ Returns:
+ - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action.
+ - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation.
+ Shapes:
+ - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
+ - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
+ - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size.
+ """
+ policy, value = self.prediction_network(latent_state)
+ return policy, value
+
+ def _dynamics(self, latent_state: torch.Tensor, reward_hidden_state: Tuple,
+ action: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor], torch.Tensor]:
+ """
+ Overview:
+ Concatenate ``latent_state`` and ``action`` and use the dynamics network to predict ``next_latent_state``,
+ ``value_prefix`` and ``next_reward_hidden_state``.
+ Arguments:
+ - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state.
+ - reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The input hidden state of LSTM about reward.
+ - action (:obj:`torch.Tensor`): The predicted action to rollout.
+ Returns:
+ - next_latent_state (:obj:`torch.Tensor`): The predicted latent state of the next timestep.
+ - next_reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The output hidden state of LSTM about reward.
+ - value_prefix (:obj:`torch.Tensor`): The predicted prefix sum of value for input state.
+ Shapes:
+ - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
+ - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - value_prefix (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + """ + # NOTE: the discrete action encoding type is important for some environments + + if not self.continuous_action_space: + # discrete action space + if self.discrete_action_encoding_type == 'one_hot': + # Stack latent_state with the one hot encoded action + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + + # transform action to one-hot encoding. + # action_one_hot shape: (batch_size, action_space_size), e.g., (8, 4) + action_one_hot = torch.zeros(action.shape[0], self.action_space_size, device=action.device) + # transform action to torch.int64 + action = action.long() + action_one_hot.scatter_(1, action, 1) + action_encoding = action_one_hot + elif self.discrete_action_encoding_type == 'not_one_hot': + action_encoding = action / self.action_space_size + if len(action_encoding.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action_encoding = action_encoding.unsqueeze(-1) + else: + # continuous action space + if len(action.shape) == 1: + # (batch_size, ) -> (batch_size, 1) + # e.g., torch.Size([8]) -> torch.Size([8, 1]) + action = action.unsqueeze(-1) + elif len(action.shape) == 3: + # (batch_size, action_dim, 1) -> (batch_size, action_dim) + # e.g., torch.Size([8, 2, 1]) -> torch.Size([8, 2]) + action = action.squeeze(-1) + + action_encoding = action + + action_encoding = action_encoding.to(latent_state.device).float() + # state_action_encoding shape: (batch_size, latent_state[1] + action_dim]) or + # (batch_size, latent_state[1] + action_space_size]) depending on the discrete_action_encoding_type. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + next_latent_state, next_reward_hidden_state, value_prefix = self.dynamics_network( + state_action_encoding, reward_hidden_state + ) + + if not self.state_norm: + return next_latent_state, next_reward_hidden_state, value_prefix + else: + next_latent_state_normalized = renormalize(next_latent_state) + return next_latent_state_normalized, next_reward_hidden_state, value_prefix + + def project(self, latent_state: torch.Tensor, with_grad=True) -> torch.Tensor: + """ + Overview: + Project the latent state to a lower dimension to calculate the self-supervised loss, which is proposed in EfficientZero. + For more details, please refer to the paper ``Exploring Simple Siamese Representation Learning``. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - with_grad (:obj:`bool`): Whether to calculate gradient for the projection result. + Returns: + - proj (:obj:`torch.Tensor`): The result embedding vector of projection operation. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. + - proj (:obj:`torch.Tensor`): :math:`(B, projection_output_dim)`, where B is batch_size. 
+
+ Examples:
+ >>> latent_state = torch.randn(256, 64)
+ >>> output = self.project(latent_state)
+ >>> output.shape # (256, 1024)
+ """
+ proj = self.projection(latent_state)
+
+ if with_grad:
+ # with grad, use prediction_head
+ return self.prediction_head(proj)
+ else:
+ return proj.detach()
+
+ def get_params_mean(self):
+ return get_params_mean(self)
+
+
+class PredictionNetworkMD(nn.Module):
+
+ def __init__(
+ self,
+ agent_num,
+ single_agent_action_size,
+ continuous_action_space,
+ action_space_size,
+ num_channels,
+ common_layer_num: int = 2,
+ fc_value_layers: SequenceType = [32],
+ output_support_size: int = 601,
+ last_linear_layer_init_zero: bool = True,
+ activation: Optional[nn.Module] = nn.ReLU(inplace=True),
+ # ==============================================================
+ # specific sampled related config
+ # ==============================================================
+ sigma_type='conditioned',
+ fixed_sigma_value: float = 0.3,
+ bound_type: str = None,
+ norm_type: str = 'BN',
+ output_separate_logit: bool = False,
+ ):
+ """
+ Overview:
+ The definition of policy and value prediction network, which is used to predict value and policy by the
+ given latent state.
+ The networks are mainly built on fully connected layers.
+ Arguments:
+ - agent_num (:obj:`int`): The number of agents in the environment.
+ - single_agent_action_size (:obj:`int`): The number of actions for each agent.
+ - continuous_action_space (:obj:`bool`): The type of action space. default set it to False.
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \
+ space, it is the number of discrete actions. For continuous action space, it is the dimension of \
+ continuous action.
+ - num_channels (:obj:`int`): The num of channels in latent states.
+ - fc_value_layers (:obj:`SequenceType`): hidden layers of the value prediction head (MLP head).
+ - output_support_size (:obj:`int`): dim of value output.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of value/policy mlp, default sets it to True.
+ - output_separate_logit (:obj:`bool`): Whether to output a separate logit for each agent instead of the joint-action logits.
+ # ==============================================================
+ # specific sampled related config
+ # ==============================================================
+ # see ``ReparameterizationHead`` in ``ding.model.common.head`` for more details about the following arguments.
+ - sigma_type (:obj:`str`): the type of sigma in policy head of prediction network, options={'conditioned', 'fixed'}.
+ - fixed_sigma_value (:obj:`float`): the fixed sigma value in policy head of prediction network.
+ - bound_type (:obj:`str`): The type of bound in networks. default set it to None.
+ - norm_type (:obj:`str`): The type of normalization in networks. default set it to 'BN'.
+ """ + super().__init__() + self.num_channels = num_channels + self.continuous_action_space = continuous_action_space + self.norm_type = norm_type + self.sigma_type = sigma_type + self.fixed_sigma_value = fixed_sigma_value + self.bound_type = bound_type + self.action_space_size = action_space_size + if self.continuous_action_space: + self.action_encoding_dim = self.action_space_size + else: + self.action_encoding_dim = 1 + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=2, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + # sampled related core code + if self.continuous_action_space: + self.policy_multi_head = MultiHead( + head_cls=ReparameterizationHead, + hidden_size=self.num_channels, + output_size_list=[single_agent_action_size for _ in range(agent_num)], + layer_num=2, + sigma_type=self.sigma_type, + fixed_sigma_value=self.fixed_sigma_value, + activation=nn.ReLU(), + norm_type=None, + bound_type=self.bound_type + ) + else: + self.policy_multi_head = MultiHead( + head_cls=DiscreteHead, + hidden_size=self.num_channels, + output_size_list=[single_agent_action_size for _ in range(agent_num)], + ) + self.output_separate_logit = output_separate_logit + + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, in_channels). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor. If action space is discrete, shape is (B, action_space_size). + If action space is continuous, shape is (B, action_space_size * 2). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + """ + x_prediction_common = self.fc_prediction_common(latent_state) + value = self.fc_value_head(x_prediction_common) + + # sampled related core code + if not self.continuous_action_space: + # policy_list: {'logit': [policy1, policy2, ...],} + # policyi shape: (B, action_space_size) + policy_list = self.policy_multi_head(x_prediction_common)['logit'] + if not self.output_separate_logit: + # The joint action space policy is the product of each agent policy + # policy shape: (B, action_space_size^^agent_num) + batch_size = latent_state.size(0) + joint_logits_batches = [] + for i in range(batch_size): + current_batch = [policy[i] for policy in policy_list] + cartesian_prod_result = torch.cartesian_prod(*current_batch) + joint_logits = cartesian_prod_result.prod(dim=1) + joint_logits_batches.append(joint_logits) + policy = torch.stack(joint_logits_batches) + else: + # policy_list: [policy1, policy2, ...] 
+ # policy sahpe: (B, agent_num, action_space_size) + policy = torch.stack(policy_list, dim=1) + elif self.continuous_action_space: + # policy = torch.cat([policy['mu'], policy['sigma']], dim=-1) + # TODO(rjy): complete the continuous action space policy + pass + + return policy, value diff --git a/lzero/model/tests/test_common_gcn.py b/lzero/model/tests/test_common_gcn.py new file mode 100644 index 000000000..e8ec41d1c --- /dev/null +++ b/lzero/model/tests/test_common_gcn.py @@ -0,0 +1,104 @@ +import torch +import numpy as np +from torch import nn +from lzero.model.common_gcn import RepresentationNetworkGCN, RGCNLayer + +# ... + + +class TestLightZeroEnvWrapper: + + # ... + def test_representation_network_gcn_with_dict_obs(self): + robot_state_dim = 10 + human_state_dim = 5 + robot_num = 3 + human_num = 2 + hidden_channels = 64 + layer_num = 2 + activation = nn.ReLU(inplace=True) + last_linear_layer_init_zero = True + norm_type = 'BN' + + representation_network = RepresentationNetworkGCN( + robot_state_dim=robot_state_dim, + human_state_dim=human_state_dim, + robot_num=robot_num, + human_num=human_num, + hidden_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + last_linear_layer_init_zero=last_linear_layer_init_zero, + norm_type=norm_type, + ) + + # Create dummy input + batch_size = 4 + x = { + 'robot_state': torch.randn(batch_size, robot_num, robot_state_dim), + 'human_state': torch.randn(batch_size, human_num, human_state_dim) + } + + # Forward pass + output = representation_network(x) + + # Check output shape + assert output.shape == (batch_size, hidden_channels) + + # Check output type + assert isinstance(output, torch.Tensor) + + # Check intermediate shape + assert representation_network.rgcn(x).shape == (batch_size, robot_num + human_num, hidden_channels) + + # Check intermediate type + assert isinstance(representation_network.rgcn(x), torch.Tensor) + + def test_representation_network_gcn_with_2d_array_obs(self): + robot_state_dim = 10 + human_state_dim = 10 # 2d_array_obs, so the dimensions must be the same + robot_num = 3 + human_num = 2 + hidden_channels = 64 + layer_num = 2 + activation = nn.ReLU(inplace=True) + last_linear_layer_init_zero = True + norm_type = 'BN' + + representation_network = RepresentationNetworkGCN( + robot_state_dim=robot_state_dim, + human_state_dim=human_state_dim, + robot_num=robot_num, + human_num=human_num, + hidden_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + last_linear_layer_init_zero=last_linear_layer_init_zero, + norm_type=norm_type, + ) + + # Create dummy input + batch_size = 4 + x = torch.randn(batch_size, robot_num + human_num, robot_state_dim) + + # Forward pass + output = representation_network(x) + + # Check output shape + assert output.shape == (batch_size, hidden_channels) + + # Check output type + assert isinstance(output, torch.Tensor) + + # Check intermediate shape + assert representation_network.rgcn(x).shape == (batch_size, robot_num + human_num, hidden_channels) + + # Check intermediate type + assert isinstance(representation_network.rgcn(x), torch.Tensor) + + +if __name__ == '__main__': + test = TestLightZeroEnvWrapper() + test.test_representation_network_gcn_with_dict_obs() + test.test_representation_network_gcn_with_2d_array_obs() + print("All tests passed.") diff --git a/lzero/model/tests/test_rgcn.py b/lzero/model/tests/test_rgcn.py new file mode 100644 index 000000000..836f36b2e --- /dev/null +++ b/lzero/model/tests/test_rgcn.py @@ -0,0 +1,56 @@ +import torch +import 
torch.nn as nn +from torch.nn import functional as F +from itertools import product +import unittest +from lzero.model.common_gcn import RGCNLayer + + +class TestRGCNLayer(unittest.TestCase): + + def setUp(self): + self.robot_state_dim = 10 + self.human_state_dim = 10 + self.similarity_function = 'embedded_gaussian' + self.batch_size = 4 + self.num_nodes = 5 # Suppose 5 robots and 5 humans + + # Create a RGCNLayer object + self.rgcn_layer = RGCNLayer( + robot_state_dim=self.robot_state_dim, + human_state_dim=self.human_state_dim, + similarity_function=self.similarity_function, + num_layer=2, + X_dim=32, + layerwise_graph=False, + skip_connection=True + ) + + # Creating dummy inputs + self.state = { + 'robot_state': torch.randn(self.batch_size, self.num_nodes, self.robot_state_dim), + 'human_state': torch.randn(self.batch_size, self.num_nodes, self.human_state_dim) + } + + def test_forward_shape(self): + # Forward pass + output = self.rgcn_layer(self.state) + expected_shape = (self.batch_size, self.num_nodes * 2, 32) # Since final_state_dim is set to X_dim + self.assertEqual(output.shape, expected_shape, "Output shape is incorrect.") + + def test_similarity_function(self): + # Check if the similarity matrix computation is working as expected + # This only checks for one similarity function due to space constraints + if self.similarity_function == 'embedded_gaussian': + X = torch.randn(self.batch_size, self.num_nodes * 2, 32) + A = self.rgcn_layer.compute_similarity_matrix(X) + self.assertEqual( + A.shape, (self.batch_size, self.num_nodes * 2, self.num_nodes * 2), + "Similarity matrix shape is incorrect." + ) + self.assertTrue(torch.all(A >= 0) and torch.all(A <= 1), "Similarity matrix values should be normalized.") + + +# Running the tests +if __name__ == '__main__': + unittest.main() diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 3a94baf51..7c0e50c42 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -218,6 +218,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'EfficientZeroModel', ['lzero.model.efficientzero_model'] elif self._cfg.model.model_type == "mlp": return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp'] + elif self._cfg.model.model_type == "mlp_md": + return 'EfficientZeroModelMD', ['lzero.model.efficientzero_model_md'] else: raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) @@ -368,7 +370,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: target_normalized_visit_count_masked = torch.index_select( target_normalized_visit_count_init_step, 0, non_masked_indices ) - target_policy_entropy = -((target_normalized_visit_count_masked+1e-6) * (target_normalized_visit_count_masked+1e-6).log()).sum(-1).mean() + target_policy_entropy = -( + (target_normalized_visit_count_masked + 1e-6) * (target_normalized_visit_count_masked + 1e-6).log() + ).sum(-1).mean() else: # Set target_policy_entropy to log(|A|) if all rows are masked target_policy_entropy = torch.log(torch.tensor(target_normalized_visit_count_init_step.shape[-1])) @@ -435,7 +439,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: target_normalized_visit_count_masked = torch.index_select( target_normalized_visit_count, 0, non_masked_indices ) - target_policy_entropy += -((target_normalized_visit_count_masked+1e-6) * (target_normalized_visit_count_masked+1e-6).log()).sum(-1).mean() + target_policy_entropy += -( + 
(target_normalized_visit_count_masked + 1e-6) * (target_normalized_visit_count_masked + 1e-6).log() + ).sum(-1).mean() else: # Set target_policy_entropy to log(|A|) if all rows are masked target_policy_entropy += torch.log(torch.tensor(target_normalized_visit_count.shape[-1])) @@ -576,8 +582,7 @@ def _forward_collect( pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), - reward_hidden_state_roots[1].detach().cpu().numpy() + reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() ) policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -647,7 +652,13 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): + def _forward_eval( + self, + data: torch.Tensor, + action_mask: list, + to_play: -1, + ready_env_id: np.array = None, + ): """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 0acc66b07..d5597db57 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -222,6 +222,10 @@ def default_model(self) -> Tuple[str, List[str]]: return 'MuZeroModel', ['lzero.model.muzero_model'] elif self._cfg.model.model_type == "mlp": return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] + elif self._cfg.model.model_type == "rgcn": + return 'MuZeroModelGCN', ['lzero.model.muzero_model_gcn'] + elif self._cfg.model.model_type == "mlp_md": + return 'MuZeroModelMD', ['lzero.model.muzero_model_md'] else: raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) @@ -441,9 +445,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ============================================================== # weighted loss with masks (some invalid states which are out of trajectory.) 
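# Illustrative sketch (not part of the patch; the coefficient values are made-up examples): the
# per-sample loss terms below are combined with their configured weights and then scaled by
# per-sample importance weights (e.g. from prioritized replay) before taking the mean.
import torch
B = 4  # batch size
consistency_loss, policy_loss, value_loss, reward_loss, policy_entropy_loss = (torch.rand(B) for _ in range(5))
weights = torch.ones(B)  # per-sample weights from the replay buffer
loss = (2.0 * consistency_loss + 1.0 * policy_loss + 0.25 * value_loss
        + 1.0 * reward_loss + 0.0 * policy_entropy_loss)
weighted_total_loss = (weights * loss).mean()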
loss = ( - self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + - self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + - self._cfg.policy_entropy_loss_weight * policy_entropy_loss + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss + + self._cfg.policy_entropy_loss_weight * policy_entropy_loss ) weighted_total_loss = (weights * loss).mean() @@ -453,8 +457,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in weighted_total_loss.backward() if self._cfg.multi_gpu: self.sync_gradients(self._learn_model) - total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), - self._cfg.grad_clip_value) + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), self._cfg.grad_clip_value + ) self._optimizer.step() if self._cfg.lr_piecewise_constant_decay: self.lr_scheduler.step() @@ -477,7 +482,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'weighted_total_loss': weighted_total_loss.item(), 'total_loss': loss.mean().item(), 'policy_loss': policy_loss.mean().item(), - 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), + 'policy_entropy': -policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1), 'reward_loss': reward_loss.mean().item(), 'value_loss': value_loss.mean().item(), 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, @@ -641,13 +646,21 @@ def _get_target_obs_index_in_step_k(self, step): if self._cfg.model.model_type == 'conv': beg_index = self._cfg.model.image_channel * step end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) - elif self._cfg.model.model_type == 'mlp': + elif self._cfg.model.model_type == 'mlp' or self._cfg.model.model_type == 'mlp_md': + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type == 'rgcn': beg_index = self._cfg.model.observation_shape * step end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) return beg_index, end_index - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, - ready_env_id: np.array = None, ) -> Dict: + def _forward_eval( + self, + data: torch.Tensor, + action_mask: list, + to_play: int = -1, + ready_env_id: np.array = None, + ) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -670,7 +683,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. """ self._eval_model.eval() - active_eval_env_num = data.shape[0] + if type(data) is dict: + active_eval_env_num = data['robot_state'].shape[0] + else: + active_eval_env_num = data.shape[0] with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 7003f6808..761eaa96b 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -234,6 +234,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] elif self._cfg.model.model_type == "mlp": return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] + elif self._cfg.model.model_type == "mlp_md": + return 'SampledEfficientZeroModelMD', ['lzero.model.sampled_efficientzero_model_md'] else: raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) @@ -498,9 +500,9 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # ============================================================== # weighted loss with masks (some invalid states which are out of trajectory.) loss = ( - self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + - self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss + - self._cfg.policy_entropy_loss_weight * policy_entropy_loss + self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss + + self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss + + self._cfg.policy_entropy_loss_weight * policy_entropy_loss ) weighted_total_loss = (weights * loss).mean() @@ -552,33 +554,37 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: } if self._cfg.model.continuous_action_space: - return_data.update({ - # ============================================================== - # sampled related core code - # ============================================================== - 'policy_mu_max': mu[:, 0].max().item(), - 'policy_mu_min': mu[:, 0].min().item(), - 'policy_mu_mean': mu[:, 0].mean().item(), - 'policy_sigma_max': sigma.max().item(), - 'policy_sigma_min': sigma.min().item(), - 'policy_sigma_mean': sigma.mean().item(), - # take the fist dim in action space - 'target_sampled_actions_max': target_sampled_actions[:, :, 0].max().item(), - 'target_sampled_actions_min': target_sampled_actions[:, :, 0].min().item(), - 'target_sampled_actions_mean': target_sampled_actions[:, :, 0].mean().item(), - 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() - }) + return_data.update( + { + # ============================================================== + # sampled related core code + # ============================================================== + 'policy_mu_max': mu[:, 0].max().item(), + 'policy_mu_min': mu[:, 0].min().item(), + 'policy_mu_mean': mu[:, 0].mean().item(), + 'policy_sigma_max': sigma.max().item(), + 'policy_sigma_min': sigma.min().item(), + 'policy_sigma_mean': sigma.mean().item(), + # take the fist dim in action space + 'target_sampled_actions_max': target_sampled_actions[:, :, 0].max().item(), + 'target_sampled_actions_min': target_sampled_actions[:, :, 0].min().item(), + 'target_sampled_actions_mean': target_sampled_actions[:, :, 0].mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() + } + ) else: - return_data.update({ - # ============================================================== - # sampled related core code - # ============================================================== - # take the fist dim in 
action space - 'target_sampled_actions_max': target_sampled_actions[:, :].float().max().item(), - 'target_sampled_actions_min': target_sampled_actions[:, :].float().min().item(), - 'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(), - 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() - }) + return_data.update( + { + # ============================================================== + # sampled related core code + # ============================================================== + # take the fist dim in action space + 'target_sampled_actions_max': target_sampled_actions[:, :].float().max().item(), + 'target_sampled_actions_min': target_sampled_actions[:, :].float().min().item(), + 'target_sampled_actions_mean': target_sampled_actions[:, :].float().mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item() + } + ) return return_data @@ -679,9 +685,9 @@ def _calculate_policy_loss_cont( if self._cfg.policy_loss_type == 'KL': # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] + torch.exp(target_log_prob_sampled_actions.detach()) * + (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) + ).sum(-1) * mask_batch[:, unroll_step] elif self._cfg.policy_loss_type == 'cross_entropy': # cross_entropy loss: - sum(p * log (q) ) policy_loss += -torch.sum( @@ -722,8 +728,9 @@ def _calculate_policy_loss_disc( torch.nonzero(mask_batch[:, unroll_step]).squeeze(-1) ) - target_policy_entropy = -((target_normalized_visit_count_masked + 1e-6) * ( - target_normalized_visit_count_masked + 1e-6).log()).sum(-1).mean() + target_policy_entropy = -( + (target_normalized_visit_count_masked + 1e-6) * (target_normalized_visit_count_masked + 1e-6).log() + ).sum(-1).mean() # shape: (batch_size, num_unroll_steps, num_of_sampled_actions, action_dim) -> (batch_size, # num_of_sampled_actions, action_dim) e.g. 
(4, 6, 20, 2) -> (4, 20, 2) @@ -767,9 +774,9 @@ def _calculate_policy_loss_disc( if self._cfg.policy_loss_type == 'KL': # KL divergence loss: sum( p* log(p/q) ) = sum( p*log(p) - p*log(q) )= sum( p*log(p)) - sum( p*log(q) ) policy_loss += ( - torch.exp(target_log_prob_sampled_actions.detach()) * - (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) - ).sum(-1) * mask_batch[:, unroll_step] + torch.exp(target_log_prob_sampled_actions.detach()) * + (target_log_prob_sampled_actions.detach() - log_prob_sampled_actions) + ).sum(-1) * mask_batch[:, unroll_step] elif self._cfg.policy_loss_type == 'cross_entropy': # cross_entropy loss: - sum(p * log (q) ) policy_loss += -torch.sum( @@ -791,8 +798,13 @@ def _init_collect(self) -> None: self._collect_mcts_temperature = 1 def _forward_collect( - self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, - epsilon: float = 0.25, ready_env_id: np.array = None, + self, + data: torch.Tensor, + action_mask: list = None, + temperature: np.ndarray = 1, + to_play=-1, + epsilon: float = 0.25, + ready_env_id: np.array = None, ): """ Overview: @@ -830,8 +842,7 @@ def _forward_collect( pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() reward_hidden_state_roots = ( - reward_hidden_state_roots[0].detach().cpu().numpy(), - reward_hidden_state_roots[1].detach().cpu().numpy() + reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() ) policy_logits = policy_logits.detach().cpu().numpy().tolist() @@ -931,7 +942,13 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): + def _forward_eval( + self, + data: torch.Tensor, + action_mask: list, + to_play: -1, + ready_env_id: np.array = None, + ): """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 8bb5f51ba..272c2cec5 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -210,24 +210,23 @@ def _compute_priorities(self, i: int, pred_values_lst: List[float], search_value if self.policy_config.use_priority: # Calculate priorities. The priorities are the L1 losses between the predicted # values and the search values. We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. + # means the loss is calculated for each element individually, instead of being summed or averaged. # A small constant (1e-6) is added to the results to avoid zero priorities. This # is done because zero priorities could potentially cause issues in some scenarios. 
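# Illustrative sketch (not part of the patch; the values are made up): the priority computed below is
# just the element-wise L1 distance between predicted and search values, offset by 1e-6.
import torch
from torch.nn import L1Loss
pred_values = torch.tensor([0.1, 0.5, -0.2])
search_values = torch.tensor([0.0, 0.7, -0.1])
priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6
# -> approximately [0.1, 0.2, 0.1]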
pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: # priorities is None -> use the max priority for all newly collected data priorities = None return priorities - def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], - last_game_priorities: List[np.ndarray], - game_segments: List[GameSegment], done: np.ndarray) -> None: + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[GameSegment], last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: """ Overview: Save the game segment to the pool if the current game is finished, padding it if necessary. @@ -270,12 +269,18 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm # pad over and save if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, - next_segment_improved_policy=pad_improved_policy_prob) + last_game_segments[i].pad_over( + pad_obs_lst, + pad_reward_lst, + pad_root_values_lst, + pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob + ) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, - next_chances=chance_lst) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, next_chances=chance_lst + ) else: last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) """ @@ -414,7 +419,7 @@ def collect(self, stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) # stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== # policy forward @@ -437,10 +442,7 @@ def collect(self, if self.policy_config.gumbel_algo: improved_policy_dict_no_env_id = {k: v['improved_policy_probs'] for k, v in policy_output.items()} - completed_value_no_env_id = { - k: v['roots_completed_value'] - for k, v in policy_output.items() - } + completed_value_no_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} # TODO(pu): subprocess actions = {} distributions_dict = {} @@ -488,8 +490,11 @@ def collect(self, distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] ) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats( + distributions_dict[env_id], + value_dict[env_id], + improved_policy=improved_policy_dict[env_id] + ) else: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} @@ -577,6 +582,22 @@ def collect(self, 'step': self._env_info[env_id]['step'], 
'visit_entropy': visit_entropies_lst[env_id] / eps_steps_lst[env_id], } + if timestep.info.get('performance_info') is not None: + # this branch is for the performance evaluation of crowdsim env + mean_aoi = timestep.info['performance_info']['mean_aoi'] + mean_transmit_data = timestep.info['performance_info']['mean_transmit_data'] + mean_energy_consumption = timestep.info['performance_info']['mean_energy_consumption'] + transmitted_data_ratio = timestep.info['performance_info']['transmitted_data_ratio'] + human_coverage = timestep.info['performance_info']['human_coverage'] + info.update( + { + 'mean_aoi': mean_aoi, + 'mean_transmit_data': mean_transmit_data, + 'mean_energy_consumption': mean_energy_consumption, + 'transmitted_data_ratio': transmitted_data_ratio, + 'human_coverage': human_coverage, + } + ) if self.policy_config.gumbel_algo: info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] collected_episode += 1 @@ -728,6 +749,22 @@ def _output_log(self, train_iter: int) -> None: 'visit_entropy': np.mean(visit_entropy), # 'each_reward': episode_reward, } + if self._episode_info[0].get('mean_aoi') is not None: + # this branch is for the performance evaluation of crowdsim env + episode_aoi = [d['mean_aoi'] for d in self._episode_info] + episode_energy_consumption = [d['mean_energy_consumption'] for d in self._episode_info] + episode_transmitted_data_ratio = [d['transmitted_data_ratio'] for d in self._episode_info] + episode_human_coverage = [d['human_coverage'] for d in self._episode_info] + mean_transmit_data = [d['mean_transmit_data'] for d in self._episode_info] + info.update( + { + 'episode_mean_aoi': np.mean(episode_aoi), + 'episode_mean_transmit_data': np.mean(mean_transmit_data), + 'episode_mean_energy_consumption': np.mean(episode_energy_consumption), + 'episode_mean_transmitted_data_ratio': np.mean(episode_transmitted_data_ratio), + 'episode_mean_human_coverage': np.mean(episode_human_coverage), + } + ) if self.policy_config.gumbel_algo: info['completed_value'] = np.mean(completed_value) self._episode_info.clear() diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index c67bb55f2..983191f0c 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -235,8 +235,9 @@ def eval( time.sleep(retry_waiting_time) self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, - self._env._env_states) + 'After sleeping {}s, the current _env_states is {}'.format( + retry_waiting_time, self._env._env_states + ) ) init_obs = self._env.ready_obs @@ -343,6 +344,9 @@ def eval( self._policy.reset([env_id]) reward = t.info['eval_episode_return'] saved_info = {'eval_episode_return': t.info['eval_episode_return']} + if 'performance_info' in t.info: + # this branch is for crowdsim env + saved_info.update(t.info['performance_info']) if 'episode_info' in t.info: saved_info.update(t.info['episode_info']) eval_monitor.update_info(env_id, saved_info) @@ -367,7 +371,8 @@ def eval( ) time.sleep(retry_waiting_time) self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 + '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' 
+ + '=' * 10 ) self._logger.info( 'After sleeping {}s, the current _env_states is {}'.format( @@ -440,9 +445,8 @@ def eval( stop_flag = episode_return >= self._stop_value and train_iter > 0 if stop_flag: self._logger.info( - "[LightZero serial pipeline] " + - "Current episode_return: {} is greater than stop_value: {}".format(episode_return, - self._stop_value) + + "[LightZero serial pipeline] " + "Current episode_return: {} is greater than stop_value: {}". + format(episode_return, self._stop_value) + ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." ) diff --git a/zoo/CrowdSim/__init__.py b/zoo/CrowdSim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/CrowdSim/config/CrowdSim_efficientzero_config.py b/zoo/CrowdSim/config/CrowdSim_efficientzero_config.py new file mode 100644 index 000000000..fc81f2c46 --- /dev/null +++ b/zoo/CrowdSim/config/CrowdSim_efficientzero_config.py @@ -0,0 +1,92 @@ +from easydict import EasyDict +import os + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(3e5) +reanalyze_ratio = 0. +robot_num = 2 +# different human_num for different datasets +human_num = 10 # purdue dataset +# human_num = 33 # NCSU dataset +# human_num = 92 # KAIST dataset +one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +CrowdSim_efficientzero_config = dict( + exp_name= + f'result/crowd_num_human/CrowdSim_efficientzero_step{max_env_step}_uav{robot_num}_human{human_num}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name='CrowdSim-v0', + robot_num=robot_num, + human_num=human_num, + one_uav_action_space=one_uav_action_space, + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(robot_num + human_num) * 4, + action_space_size=(len(one_uav_action_space)) ** robot_num, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + discrete_action_encoding_type='one_hot', + # res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +CrowdSim_efficientzero_config = EasyDict(CrowdSim_efficientzero_config) +main_config = CrowdSim_efficientzero_config + +CrowdSim_efficientzero_create_config = dict( + env=dict( + type='crowdsim_lightzero', + import_names=['zoo.CrowdSim.envs.CrowdSim_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +CrowdSim_efficientzero_create_config = EasyDict(CrowdSim_efficientzero_create_config) +create_config = CrowdSim_efficientzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/CrowdSim/config/CrowdSim_muzero_config.py b/zoo/CrowdSim/config/CrowdSim_muzero_config.py new file mode 100644 index 000000000..3705d2850 --- /dev/null +++ b/zoo/CrowdSim/config/CrowdSim_muzero_config.py @@ -0,0 +1,93 @@ +from easydict import EasyDict +import os +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(3e5) +reanalyze_ratio = 0. +robot_num = 2 +human_num = 10 # purdue +# human_num = 33 # NCSU +# human_num = 92 # KAIST +one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +CrowdSim_muzero_config = dict( + exp_name= + f'result/crowd_num_human/CrowdSim_muzero_ssl_step{max_env_step}_uav{robot_num}__human{human_num}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name='CrowdSim-v0', + robot_num=robot_num, + human_num=human_num, + one_uav_action_space=one_uav_action_space, + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(robot_num + human_num) * 4, + action_space_size=(len(one_uav_action_space)) ** robot_num, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +CrowdSim_muzero_config = EasyDict(CrowdSim_muzero_config) +main_config = CrowdSim_muzero_config + +CrowdSim_muzero_create_config = dict( + env=dict( + type='crowdsim_lightzero', + import_names=['zoo.CrowdSim.envs.CrowdSim_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +CrowdSim_muzero_create_config = EasyDict(CrowdSim_muzero_create_config) +create_config = CrowdSim_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/CrowdSim/config/__init__.py b/zoo/CrowdSim/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/CrowdSim/config/crowdsim_efficientzero_md_config.py b/zoo/CrowdSim/config/crowdsim_efficientzero_md_config.py new file mode 100644 index 000000000..85a799a83 --- /dev/null +++ b/zoo/CrowdSim/config/crowdsim_efficientzero_md_config.py @@ -0,0 +1,101 @@ +from easydict import EasyDict +import os +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(3e5) +reanalyze_ratio = 0. +robot_num = 2 +# different human_num for different datasets +human_num = 59 # purdue dataset +# human_num = 33 # NCSU dataset +# human_num = 92 # KAIST dataset +one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] +transmit_v = 20 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +CrowdSim_efficientzero_md_config = dict( + exp_name= + f'result/new_env/new_CrowdSim_ez_md_ssl_vt{transmit_v}_step{max_env_step}_uav{robot_num}__human{human_num}_seed0', + env=dict( + env_mode='hard', + transmit_v=transmit_v, + obs_mode='1-dim-array', + env_name='CrowdSim-v0', + dataset='purdue', + robot_num=robot_num, + human_num=human_num, + one_uav_action_space=one_uav_action_space, + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + agent_num=robot_num, + observation_shape=(robot_num + human_num) * 4, + obs_mode='1-dim-array', + robot_state_dim=4, + human_state_dim=4, + robot_num=robot_num, + human_num=human_num, + single_agent_action_size=len(one_uav_action_space), + action_space_size=(len(one_uav_action_space)) ** robot_num, + model_type='mlp_md', + output_separate_logit=False, # not output separate logit for each action. 
+ lstm_hidden_size=128, + latent_state_dim=128, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +CrowdSim_efficientzero_md_config = EasyDict(CrowdSim_efficientzero_md_config) +main_config = CrowdSim_efficientzero_md_config + +CrowdSim_efficientzero_md_create_config = dict( + env=dict( + type='crowdsim_lightzero', + import_names=['zoo.CrowdSim.envs.crowdsim_lightzero_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), +) +CrowdSim_efficientzero_md_create_config = EasyDict(CrowdSim_efficientzero_md_create_config) +create_config = CrowdSim_efficientzero_md_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/CrowdSim/config/crowdsim_muzero_md_config.py b/zoo/CrowdSim/config/crowdsim_muzero_md_config.py new file mode 100644 index 000000000..241a67e1e --- /dev/null +++ b/zoo/CrowdSim/config/crowdsim_muzero_md_config.py @@ -0,0 +1,112 @@ +from easydict import EasyDict +import os +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +# num_simulations = 25 +update_per_collect = 250 +batch_size = 256 +max_env_step = int(5e5) +reanalyze_ratio = 0. 
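+# NOTE: the flat action space used by the policy is the Cartesian product of the per-UAV action sets,
+# i.e. action_space_size = len(one_uav_action_space) ** robot_num (5 ** 2 = 25 for the defaults below).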
+robot_num = 2 +# different human_num for different datasets +human_num = 59 # purdue dataset +# human_num = 33 # NCSU dataset +# human_num = 92 # KAIST dataset +one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] +transmit_v = 120 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +CrowdSim_muzero_config = dict( + exp_name= + f'result/new_env2_hard/new_CrowdSim2_hard_womd_vc1_vt{transmit_v}_muzero_md_ssl_step{max_env_step}_uav{robot_num}_human{human_num}_ns{num_simulations}_upc{update_per_collect}_seed0', + env=dict( + env_mode='hard', + transmit_v=transmit_v, + collect_v_prob = {'1': 1, '2': 0}, + obs_mode='1-dim-array', + env_name='CrowdSim-v0', + dataset='purdue', + robot_num=robot_num, + human_num=human_num, + one_uav_action_space=one_uav_action_space, + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + # robot_observation_shape=(robot_num, 4), + # human_observation_shape=(human_num, 4), + agent_num=robot_num, + observation_shape=(robot_num + human_num) * 4, + obs_mode='1-dim-array', + robot_state_dim=4, + human_state_dim=4, + robot_num=robot_num, + human_num=human_num, + single_agent_action_size=len(one_uav_action_space), + action_space_size=(len(one_uav_action_space)) ** robot_num, + model_type='mlp_md', + output_separate_logit=False, # not output separate logit for each action. + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + # game_segment_length=120, + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + grad_clip_value=10, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +CrowdSim_muzero_config = EasyDict(CrowdSim_muzero_config) +main_config = CrowdSim_muzero_config + +CrowdSim_muzero_create_config = dict( + env=dict( + type='crowdsim_lightzero', + import_names=['zoo.CrowdSim.envs.crowdsim_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +CrowdSim_muzero_create_config = EasyDict(CrowdSim_muzero_create_config) +create_config = CrowdSim_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/CrowdSim/config/crowdsim_muzero_rgcn_config.py b/zoo/CrowdSim/config/crowdsim_muzero_rgcn_config.py new file mode 100644 index 000000000..0f3dcf147 --- /dev/null +++ b/zoo/CrowdSim/config/crowdsim_muzero_rgcn_config.py @@ -0,0 +1,102 @@ +from easydict import EasyDict +import os +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(3e5) +reanalyze_ratio = 0. +robot_num = 2 +# different human_num for different datasets +human_num = 59 # purdue dataset +# human_num = 33 # NCSU dataset +# human_num = 92 # KAIST dataset +one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +CrowdSim_muzero_config = dict( + exp_name=f'result/CrowdSim_muzerogcn_ssl_step{max_env_step}_uav{robot_num}__human{human_num}_seed0', + env=dict( + obs_mode='1-dim-array', + env_name='CrowdSim-v0', + dataset='purdue', + robot_num=robot_num, + human_num=human_num, + one_uav_action_space=one_uav_action_space, + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + # robot_observation_shape=(robot_num, 4), + # human_observation_shape=(human_num, 4), + observation_shape=(robot_num + human_num) * 4, + obs_mode='1-dim-array', + robot_state_dim=4, + human_state_dim=4, + robot_num=robot_num, + human_num=human_num, + action_space_size=(len(one_uav_action_space)) ** robot_num, + model_type='rgcn', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +CrowdSim_muzero_config = EasyDict(CrowdSim_muzero_config) +main_config = CrowdSim_muzero_config + +CrowdSim_muzero_create_config = dict( + env=dict( + type='crowdsim_lightzero', + import_names=['zoo.CrowdSim.envs.crowdsim_lightzero_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +CrowdSim_muzero_create_config = EasyDict(CrowdSim_muzero_create_config) +create_config = CrowdSim_muzero_create_config + +if __name__ == "__main__": + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step) diff --git a/zoo/CrowdSim/config/crowdsim_sez_md_config.py b/zoo/CrowdSim/config/crowdsim_sez_md_config.py new file mode 100644 index 000000000..a6ac190a8 --- /dev/null +++ b/zoo/CrowdSim/config/crowdsim_sez_md_config.py @@ -0,0 +1,110 @@ +from easydict import EasyDict +import os +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(3e5) +reanalyze_ratio = 0. +robot_num = 2 +# different human_num for different datasets +human_num = 59 # purdue dataset +# human_num = 33 # NCSU dataset +# human_num = 92 # KAIST dataset +one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] +K = 10 +transmit_v = 20 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +CrowdSim_sez_config = dict( + exp_name= + f'result/new_env/new_CrowdSim_vt{transmit_v}_sez_md_ssl_K{K}_step{max_env_step}_uav{robot_num}__human{human_num}_seed0', + env=dict( + env_mode='hard', + obs_mode='1-dim-array', + transmit_v=transmit_v, + env_name='CrowdSim-v0', + dataset='purdue', + robot_num=robot_num, + human_num=human_num, + one_uav_action_space=one_uav_action_space, + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + agent_num=robot_num, + observation_shape=(robot_num + human_num) * 4, + obs_mode='1-dim-array', + robot_state_dim=4, + human_state_dim=4, + robot_num=robot_num, + human_num=human_num, + single_agent_action_size=len(one_uav_action_space), + action_space_size=(len(one_uav_action_space)) ** robot_num, + model_type='mlp_md', + output_separate_logit=False, # not output separate logit for each action. + continuous_action_space=False, + num_of_sampled_actions=K, + lstm_hidden_size=128, + latent_state_dim=128, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. 
+            grad_clip_value=0.5,
+            num_simulations=num_simulations,
+            reanalyze_ratio=reanalyze_ratio,
+            n_episode=n_episode,
+            eval_freq=int(1e3),
+            replay_buffer_size=int(1e6),  # the size/capacity of replay_buffer, in terms of transitions.
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+        ),
+)
+
+CrowdSim_sez_config = EasyDict(CrowdSim_sez_config)
+main_config = CrowdSim_sez_config
+
+CrowdSim_sez_create_config = dict(
+    env=dict(
+        type='crowdsim_lightzero',
+        import_names=['zoo.CrowdSim.envs.crowdsim_lightzero_env'],
+    ),
+    env_manager=dict(type='base'),
+    policy=dict(
+        type='sampled_efficientzero',
+        import_names=['lzero.policy.sampled_efficientzero'],
+    ),
+    collector=dict(
+        type='episode_muzero',
+        import_names=['lzero.worker.muzero_collector'],
+    )
+)
+CrowdSim_sez_create_config = EasyDict(CrowdSim_sez_create_config)
+create_config = CrowdSim_sez_create_config
+
+if __name__ == "__main__":
+    from lzero.entry import train_muzero
+    train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/zoo/CrowdSim/entry/eval_crowdsim.py b/zoo/CrowdSim/entry/eval_crowdsim.py
new file mode 100644
index 000000000..46364bdc8
--- /dev/null
+++ b/zoo/CrowdSim/entry/eval_crowdsim.py
@@ -0,0 +1,85 @@
+from lzero.entry import eval_muzero
+import numpy as np
+
+if __name__ == "__main__":
+    """
+    Overview:
+        Main script to evaluate the MuZero model on the CrowdSim environment. The script will loop over multiple seeds,
+        evaluating a certain number of episodes per seed. Results are aggregated and printed.
+
+    Variables:
+        - model_path (:obj:`Optional[str]`): The pretrained model path, pointing to the ckpt file of the pretrained model.
+          The path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+        - seeds (:obj:`List[int]`): List of seeds to use for the evaluations.
+        - num_episodes_each_seed (:obj:`int`): Number of episodes to evaluate for each seed.
+        - total_test_episodes (:obj:`int`): Total number of test episodes, calculated as num_episodes_each_seed * len(seeds).
+        - returns_mean_seeds (:obj:`np.array`): Array of mean return values for each seed.
+        - returns_seeds (:obj:`np.array`): Array of all return values for each seed.
+    """
+    # Import the necessary configuration from the CrowdSim MuZero config in the zoo directory.
+    # module_path = '/home/nighoodRen/LightZero/result/new_env/new_CrowdSim_vt20_muzero_md_ssl_step300000_uav2__human59_seed0'
+    # import sys
+    # if module_path not in sys.path:
+    #     sys.path.append(module_path)
+    # # import main_config and create_config from the module
+    # from formatted_total_config import main_config, create_config
+    # from result.new_env.new_CrowdSim_vt20_muzero_md_ssl_step300000_uav2__human59_seed0.formatted_total_config import main_config, create_config
+    from zoo.CrowdSim.config.crowdsim_muzero_md_config import main_config, create_config
+
+    # model_path is the path to the trained MuZero model checkpoint.
+    # If no path is provided, the script will use the default model.
+    model_path = 'xxx/ckpt_best.pth.tar'
+    main_config.exp_name = 'xxx' + 'eval'  # original result folder/eval
+    # seeds is a list of seed values for the random number generator, used to initialize the environment.
+    seeds = [0]
+    # num_episodes_each_seed is the number of episodes to run for each seed.
+ num_episodes_each_seed = 1 + # total_test_episodes is the total number of test episodes, calculated as the product of the number of seeds and the number of episodes per seed + total_test_episodes = num_episodes_each_seed * len(seeds) + + # Setting the type of the environment manager to 'base' for the visualization purposes. + create_config.env_manager.type = 'base' + # The number of environments to evaluate concurrently. Set to 1 for visualization purposes. + main_config.env.evaluator_env_num = 1 + # The total number of evaluation episodes that should be run. + main_config.env.n_evaluator_episode = total_test_episodes + # A boolean flag indicating whether to render the environments in real-time. + main_config.env.render_mode_human = False + + # A boolean flag indicating whether to save the video of the environment. + main_config.env.save_replay = True + # The path where the recorded video will be saved. + main_config.env.replay_path = main_config.exp_name + '/video' # current result folder/eval + + # The maximum number of steps for each episode during evaluation. This may need to be adjusted based on the specific characteristics of the environment. + main_config.env.eval_max_episode_steps = int(20) + + # These lists will store the mean and total rewards for each seed. + returns_mean_seeds = [] + returns_seeds = [] + + # The main evaluation loop. For each seed, the MuZero model is evaluated and the mean and total rewards are recorded. + for seed in seeds: + returns_mean, returns = eval_muzero( + [main_config, create_config], + seed=seed, + num_episodes_each_seed=num_episodes_each_seed, + print_seed_details=False, + model_path=model_path + ) + print(returns_mean, returns) + returns_mean_seeds.append(returns_mean) + returns_seeds.append(returns) + + # Convert the list of mean and total rewards into numpy arrays for easier statistical analysis. + returns_mean_seeds = np.array(returns_mean_seeds) + returns_seeds = np.array(returns_seeds) + + # Printing the evaluation results. The average reward and the total reward for each seed are displayed, followed by the mean reward across all seeds. + print("=" * 20) + print( + f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s)." 
+ ) + print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.") + print("Across all seeds, the mean reward is:", returns_mean_seeds.mean()) + print("=" * 20) diff --git a/zoo/CrowdSim/envs/CrowdSim_env.py b/zoo/CrowdSim/envs/CrowdSim_env.py new file mode 100644 index 000000000..6cc3a7209 --- /dev/null +++ b/zoo/CrowdSim/envs/CrowdSim_env.py @@ -0,0 +1,113 @@ +from typing import Union, Optional + +import gym +import numpy as np +from itertools import product +import logging + +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.envs import ObsPlusPrevActRewWrapper +from ding.torch_utils import to_ndarray +from ding.utils import ENV_REGISTRY + +import zoo.CrowdSim.envs.Crowdsim.env + + +@ENV_REGISTRY.register('crowdsim_lightzero') +class CrowdSimEnv(BaseEnv): + + def __init__(self, cfg: dict = {}) -> None: + self._cfg = cfg + self._init_flag = False + self._replay_path = None + self._robot_num = self._cfg.robot_num + self._human_num = self._cfg.human_num + self._observation_space = gym.spaces.Box( + low=float("-inf"), high=float("inf"), shape=((self._robot_num + self._human_num) * 4, ), dtype=np.float32 + ) + # action space + # one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] + self.real_action_space = list(product(self._cfg.one_uav_action_space, repeat=self._robot_num)) + one_uav_action_n = len(self._cfg.one_uav_action_space) + self._action_space = gym.spaces.Discrete(one_uav_action_n ** self._robot_num) + self._action_space.seed(0) # default seed + self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32) + self._continuous = False + + def reset(self) -> np.ndarray: + if not self._init_flag: + self._env = gym.make('CrowdSim-v0', dataset=self._cfg.dataset, custom_config=self._cfg) + self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + np_seed = 100 * np.random.randint(1, 1000) + self._env.seed(self._seed + np_seed) + self._action_space.seed(self._seed + np_seed) + elif hasattr(self, '_seed'): + self._env.seed(self._seed) + self._action_space.seed(self._seed) + self._eval_episode_return = 0 + # process obs + raw_obs = self._env.reset() + obs_list = raw_obs.to_array() + # human_obs, robot_obs = obs_list + obs = np.concatenate(obs_list, axis=0).flatten() # for 1 dim e.g.(244,) + assert len(obs) == (self._robot_num + self._human_num) * 4 + action_mask = np.ones(self.action_space.n, 'int8') + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} + + return obs + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: + if isinstance(action, np.ndarray) and action.shape == (1, ): + action = action.squeeze() # 0-dim array + real_action = self.real_action_space[action] + assert isinstance(real_action, tuple) and len(real_action) == self._robot_num, "illegal action!" 
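+        # `real_action` is the tuple of per-UAV [dx, dy] moves obtained by indexing the flat discrete
+        # action into the itertools.product table built in __init__. For example, with the default
+        # 5-move `one_uav_action_space` and robot_num=2, action index 7 decodes to
+        # (one_uav_action_space[7 // 5], one_uav_action_space[7 % 5]) = ([30, 0], [-30, 0]).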
+        raw_obs, rew, done, info = self._env.step(real_action)
+        obs_list = to_ndarray(raw_obs.to_tensor())
+        obs = np.concatenate(obs_list, axis=0).flatten()  # for 1 dim e.g.(244,)
+        assert len(obs) == (self._robot_num + self._human_num) * 4
+
+        self._eval_episode_return += rew
+        if done:
+            info['eval_episode_return'] = self._eval_episode_return
+            # logging.INFO('one game finish!')
+
+        action_mask = np.ones(self.action_space.n, 'int8')
+        obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+        rew = to_ndarray([rew]).astype(np.float32)
+        return BaseEnvTimestep(obs, rew, done, info)
+
+    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+        if replay_path is None:
+            replay_path = './video'
+        self._replay_path = replay_path
+
+    def random_action(self) -> np.ndarray:
+        random_action = self.action_space.sample()
+        random_action = to_ndarray([random_action], dtype=np.int64)
+        return random_action
+
+    @property
+    def observation_space(self) -> gym.spaces.Space:
+        return self._observation_space
+
+    @property
+    def action_space(self) -> gym.spaces.Space:
+        return self._action_space
+
+    @property
+    def reward_space(self) -> gym.spaces.Space:
+        return self._reward_space
+
+    def __repr__(self) -> str:
+        return "LightZero CrowdSim Env"
diff --git a/zoo/CrowdSim/envs/Crowdsim/__init__.py b/zoo/CrowdSim/envs/Crowdsim/__init__.py
new file mode 100644
index 000000000..5fb297682
--- /dev/null
+++ b/zoo/CrowdSim/envs/Crowdsim/__init__.py
@@ -0,0 +1,7 @@
+import logging
+from gym.envs.registration import register
+logger = logging.getLogger(__name__)
+register(
+    id='CrowdSim-v0',
+    entry_point='zoo.CrowdSim.envs.Crowdsim.env.crowd_sim:CrowdSim',
+)
diff --git a/zoo/CrowdSim/envs/Crowdsim/env/__init__.py b/zoo/CrowdSim/envs/Crowdsim/env/__init__.py
new file mode 100644
index 000000000..eddfe3037
--- /dev/null
+++ b/zoo/CrowdSim/envs/Crowdsim/env/__init__.py
@@ -0,0 +1 @@
+from .crowd_sim import CrowdSim
diff --git a/zoo/CrowdSim/envs/Crowdsim/env/crowd_sim.py b/zoo/CrowdSim/envs/Crowdsim/env/crowd_sim.py
new file mode 100644
index 000000000..ce0d6b15b
--- /dev/null
+++ b/zoo/CrowdSim/envs/Crowdsim/env/crowd_sim.py
@@ -0,0 +1,463 @@
+import pandas as pd
+
+import logging
+import random
+import gym
+# from shapely.geometry import Point
+import numpy as np
+from scipy.stats import entropy
+# import folium
+# from folium.plugins import TimestampedGeoJson, AntPath
+
+from zoo.CrowdSim.envs.Crowdsim.env.model.utils import *
+from zoo.CrowdSim.envs.Crowdsim.env.model.mdp import HumanState, RobotState, JointState
+from zoo.CrowdSim.envs.Crowdsim.env.crowd_sim_base_config import get_selected_config
+
+
+class CrowdSim(gym.Env):
+    """
+    Overview:
+        LightZero version of the CrowdSim environment. This class includes methods for resetting, closing, and \
+        stepping through the environment, as well as seeding for reproducibility, saving replay videos, and generating \
+        random actions. It also includes properties for accessing the observation space, action space, and reward space of the \
+        environment. The environment is a grid world with humans and robots moving around. The robots are tasked with \
+        minimizing the average age of information (AoI) of the humans by moving to their locations and collecting data from them. \
+        The humans generate data at a constant rate, and the robots have a limited energy supply that is consumed by moving. \
+        The environment supports two modes. In the 'easy' mode, a human's AoI is reset to 0 whenever a robot within sensing \
+        range collects its data. In the 'hard' mode, each human accumulates data blocks in a queue, a robot within sensing \
+        range transmits at most transmit_v blocks per step, and the AoI is recomputed from the remaining queue instead of \
+        being reset. The environment is initialized with a dataset of human locations and timestamps, and the robots are \
+        tasked with collecting data from the humans to minimize the average AoI. An episode ends when the time limit \
+        (num_timestep steps) is reached.
+    Interface:
+        `__init__`, `reset`, `step`, `render`, `sync_human_df`, `generate_human`, `generate_robot`.
+    """
+    metadata = {'render.modes': ['human']}
+
+    def __init__(self, dataset, custom_config=None):
+        """
+        Overview:
+            Initialize the environment with a dataset name and a custom configuration. The base configuration for the \
+            chosen dataset is loaded via ``get_selected_config`` and then updated with ``custom_config``, which specifies \
+            the environment mode, number of humans, number of robots, maximum timestep, step time, start timestamp, and \
+            maximum UAV energy. See the class docstring for the description of the 'easy' and 'hard' modes.
+        Args:
+            - dataset (:obj:`str`): The name of the dataset to use (e.g. 'purdue', 'ncsu' or 'kaist').
+            - custom_config (:obj:`dict`): A dictionary containing the custom configuration for the environment. \
+                The custom configuration should include the following keys:
+                - env_mode (:obj:`str`): The environment mode ('easy' or 'hard').
+                - human_num (:obj:`int`): The number of humans in the environment.
+                - robot_num (:obj:`int`): The number of robots in the environment.
+                - num_timestep (:obj:`int`): The maximum timestep for the environment.
+                - step_time (:obj:`float`): The time per step in seconds.
+                - start_timestamp (:obj:`int`): The start timestamp for the environment.
+                - max_uav_energy (:obj:`float`): The maximum energy for the UAVs.
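+        Example (illustrative; the remaining keys are filled in from the base config of the chosen dataset):
+            >>> from easydict import EasyDict
+            >>> cfg = EasyDict(
+            ...     env_mode='hard', transmit_v=20, robot_num=2, human_num=59,
+            ...     one_uav_action_space=[[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]]
+            ... )
+            >>> env = CrowdSim(dataset='purdue', custom_config=cfg)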
+ """ + # mcfg should include: + self.time_limit = None + self.robots = None + self.humans = None + self.agent = None + self.current_timestep = None + self.phase = None + + self.config = get_selected_config(dataset) + self.config.update(custom_config) + + self.env_mode = self.config.env_mode # 'easy' or 'hard' + self.human_num = self.config.human_num + self.robot_num = self.config.robot_num + self.num_timestep = self.config.num_timestep # max timestep + self.step_time = self.config.step_time # second per step + self.start_timestamp = self.config.start_timestamp # fit timpestamp to datetime + self.max_uav_energy = self.config.max_uav_energy + # self.action_space = gym.spaces.Discrete(4**self.robot_num) # for each robot have 4 actions(up, down, left, right), then product + self.action_space = gym.spaces.Discrete(len(self.config.one_uav_action_space)) + # human obs: [px, py, remaining_data_amount, aoi] + # robot obs: [px, py, theta, energy] + self.observation_space = gym.spaces.Box( + low=float("-inf"), high=float("inf"), shape=(self.robot_num + self.human_num, 4), dtype=np.float32 + ) + + # load_dataset + self.transmit_v = self.config.transmit_v # 5*0.3Mb/s + self.nlon = self.config.nlon + self.nlat = self.config.nlat + self.lower_left = self.config.lower_left + self.upper_right = self.config.upper_right + self.human_df = pd.read_csv(self.config.dataset_dir) + logging.info("Finished reading {} rows".format(len(self.human_df))) + + self.human_df['t'] = pd.to_datetime(self.human_df['timestamp'], unit='s') # 's' stands for second + self.human_df['aoi'] = -1 # 加入aoi记录aoi + self.human_df['data_amount'] = -1 # record the remaining data amount of each human + self.human_df['energy'] = -1 # 加入energy记录energy + logging.info('Env mode: {}'.format(self.env_mode)) + logging.info('human number: {}'.format(self.human_num)) + logging.info('Robot number: {}'.format(self.robot_num)) + + # for debug + self.current_human_aoi_list = np.zeros([ + self.human_num, + ]) + self.mean_aoi_timelist = np.zeros([ + self.config.num_timestep + 1, + ]) + self.cur_data_amount_timelist = np.zeros([ + self.human_num, + ]) + self.robot_energy_timelist = np.zeros([self.config.num_timestep + 1, self.robot_num]) + self.robot_x_timelist = np.zeros([self.config.num_timestep + 1, self.robot_num]) + self.robot_y_timelist = np.zeros([self.config.num_timestep + 1, self.robot_num]) + self.update_human_timelist = np.zeros([ + self.config.num_timestep, + ]) + self.data_transmission = 0 + self.data_collection_distribution = np.zeros(self.human_num) + self.data_transmission_distribution = np.zeros(self.human_num) + + def generate_human(self, human_id, selected_data, selected_next_data): + """ + Overview: + Generate a human with the given id, selected data, and selected next data. The human is initialized with \ + the given data and next data, and the remaining data amount is set to 0. The human is also initialized with \ + an AoI of 0. + Argments: + - human_id (:obj:`int`): The id of the human. + - selected_data (:obj:`pd.DataFrame`): The selected data for the current timestep. + - selected_next_data (:obj:`pd.DataFrame`): The selected data for the next timestep. + Returns: + - human (:obj:`Human`): The generated human. 
+ """ + human = Human(human_id, self.config) + px, py, theta = get_human_position_from_list( + self.current_timestep, human_id, selected_data, selected_next_data, self.config + ) + # human obs: [px, py, data_amount, aoi] + human.set(px, py, theta, 0, 0) # initial aoi of human is 0 + return human + + def generate_robot(self, robot_id): + """ + Overview: + Generate a robot with the given id. The robot is initialized with the given id and the maximum UAV energy. + Argments: + - robot_id (:obj:`int`): The id of the robot. + Returns: + - robot (:obj:`Robot`): The generated robot. + """ + robot = Robot(robot_id, self.config) + # robot obs: [px, py, theta, energy] + robot.set(self.nlon / 2, self.nlat / 2, 0, self.max_uav_energy) # robot有energy + return robot + + def sync_human_df(self, human_id, current_timestep, aoi, data_amount): + """ + Overview: + Sync the human_df with the current timestep and aoi. + Args: + - human_id (:obj:`int`): The id of the human. + - current_timestep (:obj:`int`): The current timestep. + - aoi (:obj:`int`): The aoi of the human. + """ + current_timestamp = self.start_timestamp + current_timestep * self.step_time + current_index = self.human_df[(self.human_df.id == human_id) + & (self.human_df.timestamp == current_timestamp)].index + # self.human_df.loc[current_index, "aoi"] = aoi # slower + self.human_df.iat[current_index.values[0], 9] = aoi # faster + self.human_df.iat[current_index.values[0], 10] = data_amount + + def reset(self, phase='test', test_case=None): + """ + Overview: + Reset the environment to the initial state. The environment is reset to the start timestamp, and the humans \ + and robots are generated with the given data. The humans are initialized with the selected data and next data, \ + and the robots are initialized with the given id. The environment is also initialized with the current timestep, \ + mean AoI, robot energy, robot x, robot y, and update human timelist. The environment is considered solved when \ + the average AoI is minimized to a certain threshold or the time limit is reached. + Argments: + - phase (:obj:`str`): The phase of the environment ('train' or 'test'). + - test_case (:obj:`int`): The test case for the environment. + Returns: + - state (:obj:`JointState`): The initial state of the environment. 
+ """ + self.current_timestep = 0 + + # generate human + self.humans = [] + selected_data, selected_next_data = get_human_position_list(self.current_timestep, self.human_df, self.config) + self.generate_data_amount_per_step = 0 + self.total_generated_data_amount = 0 + for human_id in range(self.human_num): + self.humans.append(self.generate_human(human_id, selected_data, selected_next_data)) + self.generate_data_amount_per_step += self.humans[human_id].collect_v + self.sync_human_df(human_id, self.current_timestep, aoi=0, data_amount=0) + + # generate robot + self.robots = [] + for robot_id in range(self.robot_num): + self.robots.append(self.generate_robot(robot_id)) + + self.cur_data_amount_timelist = np.zeros([ + self.human_num, + ]) + self.current_human_aoi_list = np.zeros([ + self.human_num, + ]) + self.mean_aoi_timelist = np.zeros([ + self.config.num_timestep + 1, + ]) + self.mean_aoi_timelist[self.current_timestep] = np.mean(self.current_human_aoi_list) + self.robot_energy_timelist = np.zeros([self.config.num_timestep + 1, self.robot_num]) + self.robot_energy_timelist[self.current_timestep, :] = self.max_uav_energy + self.robot_x_timelist = np.zeros([self.config.num_timestep + 1, self.robot_num]) + self.robot_x_timelist[self.current_timestep, :] = self.nlon / 2 + self.robot_y_timelist = np.zeros([self.config.num_timestep + 1, self.robot_num]) + self.robot_y_timelist[self.current_timestep, :] = self.nlat / 2 + self.update_human_timelist = np.zeros([ + self.config.num_timestep, + ]) + self.data_transmission = 0 + self.data_collection_distribution = np.zeros(self.human_num) + self.data_transmission_distribution = np.zeros(self.human_num) + + # for visualization + self.plot_states = [] + self.robot_actions = [] + self.rewards = [] + self.aoi_rewards = [] + self.energy_rewards = [] + self.action_values = [] + self.plot_states.append( + [[robot.get_obs() for robot in self.robots], [human.get_obs() for human in self.humans]] + ) + + state = JointState([robot.get_obs() for robot in self.robots], [human.get_obs() for human in self.humans]) + return state + + def step(self, action): + """ + Overview: + Perform a step in the environment using the provided action, and return the next state of the environment. \ + The next state is encapsulated in a BaseEnvTimestep object, which includes the new observation, reward, done flag, \ + and info dictionary. The cumulative reward (`_eval_episode_return`) is updated with the reward obtained in this step. \ + If the episode ends (done is True), the total reward for the episode is stored in the info dictionary. + Argments: + - action (:obj:`Union[int, np.ndarray]`): The action to be performed in the environment. If the action is a 1-dimensional \ + numpy array, it is squeezed to a 0-dimension array. + Returns: + - next_state (:obj:`JointState`): The next state of the environment. + - reward (:obj:`float`): The reward obtained in this step. + - done (:obj:`bool`): A flag indicating whether the episode has ended. + - info (:obj:`dict`): A dictionary containing additional information about the environment. 
+ """ + new_robot_position = np.zeros([self.robot_num, 2]) + current_enenrgy_consume = np.zeros([ + self.robot_num, + ]) + + num_updated_human = 0 # number of humans whose AoI is updated + + for robot_id, robot in enumerate(self.robots): + new_robot_px = robot.px + action[robot_id][0] + new_robot_py = robot.py + action[robot_id][1] + robot_theta = get_theta(0, 0, action[robot_id][0], action[robot_id][1]) + # print(action[robot_id], robot_theta) + is_stopping = True if (action[robot_id][0] == 0 and action[robot_id][1] == 0) else False + is_collide = True if judge_collision(new_robot_px, new_robot_py, robot.px, robot.py, self.config) else False + + if is_stopping is True: + consume_energy = consume_uav_energy(0, self.step_time, self.config) + else: + consume_energy = consume_uav_energy(self.step_time, 0, self.config) + current_enenrgy_consume[robot_id] = consume_energy / self.config.max_uav_energy + new_energy = robot.energy - consume_energy + self.robot_energy_timelist[self.current_timestep + 1][robot_id] = new_energy + + if is_collide or (new_robot_px < 0 or new_robot_px > self.nlon or new_robot_py < 0 or new_robot_py > self.nlat): + new_robot_position[robot_id][0] = robot.px + new_robot_position[robot_id][1] = robot.py + self.robot_x_timelist[self.current_timestep + 1][robot_id] = robot.px + self.robot_y_timelist[self.current_timestep + 1][robot_id] = robot.py + robot.set(robot.px, robot.py, robot_theta, energy=new_energy) + else: + new_robot_position[robot_id][0] = new_robot_px + new_robot_position[robot_id][1] = new_robot_py + self.robot_x_timelist[self.current_timestep + 1][robot_id] = new_robot_px + self.robot_y_timelist[self.current_timestep + 1][robot_id] = new_robot_py + robot.set(new_robot_px, new_robot_py, robot_theta, energy=new_energy) + + selected_data, selected_next_data = get_human_position_list( + self.current_timestep + 1, self.human_df, self.config + ) + human_transmit_data_list = np.zeros_like(self.cur_data_amount_timelist) # 0 means no update + for human_id, human in enumerate(self.humans): + next_px, next_py, next_theta = get_human_position_from_list( + self.current_timestep + 1, human_id, selected_data, selected_next_data, self.config + ) + should_reset = judge_aoi_update([next_px, next_py], new_robot_position, self.config) + if self.env_mode == 'easy': + if should_reset: + # if the human is in the range of the robot, then part of human's data will be transmitted + if human.aoi > 1: + human_transmit_data_list[human_id] = human.aoi + else: + human_transmit_data_list[human_id] = 1 + + human.set(next_px, next_py, next_theta, aoi=0, data_amount=0) + num_updated_human += 1 + else: + # if the human is not in the range of the robot, then update the aoi of the human + human_transmit_data_list[human_id] = 0 + new_aoi = human.aoi + 1 + human.set(next_px, next_py, next_theta, aoi=new_aoi, data_amount=human.aoi) + + elif self.env_mode == 'hard': + if should_reset: + # if the human is in the range of the robot, then part of human's data will be transmitted + last_data_amount = human.data_amount + human.update(next_px, next_py, next_theta, transmitted_data=self.transmit_v) + human_transmit_data_list[human_id] = min(last_data_amount + human.collect_v, self.transmit_v) + num_updated_human += 1 + else: + # if the human is not in the range of the robot, then no data will be transmitted, \ + # and update aoi and caculate new collected data amount + human_transmit_data_list[human_id] = 0 + human.update(next_px, next_py, next_theta, transmitted_data=0) + else: + raise ValueError("env_mode 
should be 'easy' or 'hard'") + + self.cur_data_amount_timelist[human_id] = human.data_amount + self.current_human_aoi_list[human_id] = human.aoi + self.sync_human_df(human_id, self.current_timestep + 1, human.aoi, human.data_amount) + self.data_collection_distribution[human_id] += human.collect_v + self.data_transmission_distribution[human_id] += human_transmit_data_list[human_id] + + self.mean_aoi_timelist[self.current_timestep + 1] = np.mean(self.current_human_aoi_list) + self.update_human_timelist[self.current_timestep] = num_updated_human + delta_sum_transmit_data = np.sum(human_transmit_data_list) + self.data_transmission += (delta_sum_transmit_data * 0.3) # Mb, 0.02M/s per person + if self.env_mode == 'easy': + # in easy mode, the data amount generated per step is equal to the number of humans + self.total_generated_data_amount = self.num_timestep * self.human_num + elif self.env_mode == 'hard': + # in hard mode, the data amount generated per step is equal to the sum of the data amount of all humans + self.total_generated_data_amount += self.generate_data_amount_per_step + + # TODO: need to be well-defined + aoi_reward = self.mean_aoi_timelist[self.current_timestep] - self.mean_aoi_timelist[self.current_timestep + 1] + energy_reward = np.sum(current_enenrgy_consume) + reward = aoi_reward \ + - self.config.energy_factor * energy_reward + + # if hasattr(self.agent.policy, 'action_values'): + # self.action_values.append(self.agent.policy.action_values) + self.robot_actions.append(action) + self.rewards.append(reward) + self.aoi_rewards.append(aoi_reward) + self.energy_rewards.append(energy_reward) + distribution_entropy = entropy( + self.data_collection_distribution/ np.sum(self.data_collection_distribution), + self.data_transmission_distribution/np.sum(self.data_transmission_distribution) + 1e-10) + self.plot_states.append([[robot.get_obs() for robot in self.robots], + [human.get_obs() for human in self.humans]]) + + next_state = JointState([robot.get_obs() for robot in self.robots], [human.get_obs() for human in self.humans]) + + self.current_timestep += 1 + # print('This game is on',self.current_timestep,' step\n') + if self.current_timestep >= self.num_timestep: + done = True + else: + done = False + info = { + "performance_info": { + "mean_aoi": self.mean_aoi_timelist[self.current_timestep], + "mean_transmit_data": self.data_transmission / self.human_num, + "mean_energy_consumption": 1.0 - ( + np.mean(self.robot_energy_timelist[self.current_timestep]) / self.max_uav_energy), + "transmitted_data_ratio": self.data_transmission/(self.total_generated_data_amount*0.3), + "human_coverage": np.mean(self.update_human_timelist) / self.human_num, + "distribution_entropy": distribution_entropy # 增加交叉熵信息 + }, + } + + return next_state, reward, done, info + + def render(self): + """ + Overview: + Render the environment to an image. The image is generated using the matplotlib library, and it includes the \ + historical trajectories of the robots, the current positions of the robots, the sensing range of the robots, the \ + positions of the humans, and their AoI values. The image is returned as a numpy array. + Returns: + - image (:obj:`np.ndarray`): The rendered image of the environment. 
+ """ + import matplotlib.pyplot as plt + import matplotlib.patches as patches + import io + import imageio + + map_max_x = self.config.nlon + map_max_y = self.config.nlat + # Create a new figure + fig, ax = plt.subplots(figsize=(8, 6)) + plt.subplots_adjust(right=0.75) # 给数据留白 + + # Plot the historical trajectories of the robots + for timestep in range(len(self.robot_x_timelist)): + for robot_id in range(len(self.robot_x_timelist[timestep])): + ax.plot( + self.robot_x_timelist[timestep][robot_id], + self.robot_y_timelist[timestep][robot_id], + color='gray', + alpha=0.5 + ) + + # Plot the current positions of the robots + for robot in self.robots: + ax.plot(robot.px, robot.py, marker='o', markersize=5, color='blue') + + # Plot the sensing range of the robots + for robot in self.robots: + robot_x, robot_y = robot.px, robot.py + circle = patches.Circle((robot_x, robot_y), self.config.sensing_range, edgecolor='blue', facecolor='none') + ax.add_patch(circle) + + # Plot the positions of the humans and their AOI values + for human in self.humans: + human_x, human_y, aoi = human.px, human.py, human.aoi + ax.plot(human_x, human_y, marker='x', markersize=5, color='red') + ax.text(human_x, human_y, str(aoi), fontsize=8, color='black') + + # Set the title and axis labels + # ax.set_xlim(0, map_max_x) + # ax.set_ylim(0, map_max_y) + ax.set_xlabel('X') + ax.set_ylabel('Y') + + # show reward/aoi_reward/energy_reward/mean_aoi/energy in the upper right corner + reward_text = f"Reward: {self.rewards[-1] if self.rewards else 0:.2f}\n" \ + f"AOI Reward: {self.aoi_rewards[-1] if self.aoi_rewards else 0:.2f}\n" \ + f"Energy Reward: {self.energy_rewards[-1] if self.energy_rewards else 0:.2f}\n" \ + f"Mean AOI: {self.mean_aoi_timelist[self.current_timestep] if self.current_timestep < len(self.mean_aoi_timelist) else 0:.2f}\n" \ + f"Energy: {np.mean(self.robot_energy_timelist[self.current_timestep]) if self.current_timestep < len(self.robot_energy_timelist) else 0:.2f}" + plt.text(1.05, 0.95, reward_text, horizontalalignment='left', verticalalignment='top', + transform=ax.transAxes, fontsize=10, bbox=dict(facecolor='white', alpha=0.6), + clip_on=False) # Ensure text is not clipped + # Leave some blank space outside of the map + ax.margins(x=0.1, y=0.1) + ax.set_title('Crowd Simulation Visualization') + + # Render the figure to an image + fig.canvas.draw() + image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + image = image.reshape(fig.canvas.get_width_height()[::-1] + (3, )) + plt.close() + + return image diff --git a/zoo/CrowdSim/envs/Crowdsim/env/crowd_sim_base_config.py b/zoo/CrowdSim/envs/Crowdsim/env/crowd_sim_base_config.py new file mode 100644 index 000000000..008fd7e16 --- /dev/null +++ b/zoo/CrowdSim/envs/Crowdsim/env/crowd_sim_base_config.py @@ -0,0 +1,156 @@ +from easydict import EasyDict + +# define base config +base_config = EasyDict( + { + "num_timestep": 120, # 120x15=1800s=30min + "step_time": 15, # seconds per step + "max_uav_energy": 359640, # 359640 J <-- 359.64 kJ (4500mAh, 22.2V) DJI Matrice + "rotation_limit": 360, + "diameter_of_human_blockers": 0.5, # meters + "h_rx": 1.3, # meters, height of RX + "h_b": 1.7, # meters, height of a human blocker + "velocity": 18, + "frequence_band": 28, # GHz + "h_d": 120, # meters, height of drone-BS + "alpha_nlos": 113.63, + "beta_nlos": 1.16, + "zeta_nlos": 2.58, # Frequency 28GHz, sub-urban. 
channel model + "alpha_los": 84.64, + "beta_los": 1.55, + "zeta_los": 0.12, + "g_tx": 0, # dB + "g_rx": 5, # dB + "tallest_locs": None, # obstacle + "no_fly_zone": None, # obstacle + "start_timestamp": 1519894800, + "end_timestamp": 1519896600, + "energy_factor": 3, # TODO: energy factor in reward function + "robot_num": 2, + "rollout_num": 1, # 1 2 6 12 15, calculated based on robot_num + } +) + +# define all dataset configs +dataset_configs = { + 'purdue': EasyDict( + { + "lower_left": [-86.93, 40.4203], # longitude and latitude + "upper_right": [-86.9103, 40.4313], + "nlon": 200, + "nlat": 120, + "human_num": 59, + "dataset_dir": '/home/nighoodRen/CrowdSim/CrowdSim/envs/crowd_sim/dataset/purdue/59 users.csv', + "sensing_range": 23.2, # unit 23.2 + "one_uav_action_space": [ + [0, 0], [30, 0], [-30, 0], [0, 30], [0, -30], [21, 21], [21, -21], [-21, 21], [-21, -21] + ], + "max_x_distance": 1667, # meters + "max_y_distance": 1222, # meters + "density_of_human_blockers": 30000 / 1667 / 1222, # blockers/m2 + } + ), + 'ncsu': EasyDict( + { + "lower_left": [-78.6988, 35.7651], # longitude and latitude + "upper_right": [-78.6628, 35.7896], + "nlon": 3600, + "nlat": 2450, + "human_num": 33, + "dataset_dir": '/home/nighoodRen/CrowdSim/CrowdSim/envs/crowd_sim/dataset/NCSU/33 users.csv', + "sensing_range": 220, # unit 220 + "one_uav_action_space": [ + [0, 0], [300, 0], [-300, 0], [0, 300], [0, -300], [210, 210], [210, -210], [-210, 210], [-210, -210] + ], + "max_x_distance": 3255.4913305859623, # meters + "max_y_distance": 2718.3945272795013, # meters + "density_of_human_blockers": 30000 / 3255.4913305859623 / 2718.3945272795013, # blockers/m2 + } + ), + 'kaist': EasyDict( + { + "lower_left": [127.3475, 36.3597], # longitude and latitude + "upper_right": [127.3709, 36.3793], + "nlon": 2340, + "nlat": 1960, + "human_num": 92, + "dataset_dir": '/home/nighoodRen/CrowdSim/CrowdSim/envs/crowd_sim/dataset/KAIST/92 users.csv', + "sensing_range": 220, # unit 220 + "one_uav_action_space": [ + [0, 0], [300, 0], [-300, 0], [0, 300], [0, -300], [210, 210], [210, -210], [-210, 210], [-210, -210] + ], + "max_x_distance": 2100.207579392558, # meters + "max_y_distance": 2174.930950809533, # meters + "density_of_human_blockers": 30000 / 2100.207579392558 / 2174.930950809533, # blockers/m2 + } + ), + # ... 
could add more dataset configs here +} + + +# get config according to data set name +def get_selected_config(data_set_name): + if data_set_name in dataset_configs: + dataset_config = dataset_configs[data_set_name] + return EasyDict({**base_config, **dataset_config}) + else: + raise ValueError(f"Data set '{data_set_name}' not found.") + + +# r:meters, 2D distance +# threshold: dB +def try_sensing_range(r, data_set_name): + import math + config = get_selected_config(data_set_name) + p_los = math.exp( + -config.density_of_human_blockers * config.diameter_of_human_blockers * r * (config.h_b - config.h_rx) / + (config.h_d - config.h_rx) + ) + p_nlos = 1 - p_los + PL_los = config.alpha_los + config.beta_los * 10 * math.log10( + math.sqrt(r * r + config.h_d * config.h_d) + ) + config.zeta_los + PL_nlos = config.alpha_nlos + config.beta_nlos * 10 * math.log10( + math.sqrt(r * r + config.h_d * config.h_d) + ) + config.zeta_nlos + PL = p_los * PL_los + p_nlos * PL_nlos + CL = PL - config.g_tx - config.g_rx + print(p_los, p_nlos) + print(CL) + + +# Maximum Coupling Loss (110dB is recommended) +# purdue: + +# 123dB -> 560m -> 60.5 range +# 121dB -> 420m -> 45.4 range +# 119dB -> 300m -> 32.4 range +# 117dB -> 215m -> 23.2 range √ +# 115dB -> 140m -> 15 range + +# ncsu: +# 123dB -> 600m -> 600 range +# 121dB -> 435m -> 435 range +# 119dB -> 315m -> 315 range +# 117dB -> 220m -> 220 range √ +# 115dB -> 145m -> 145 range + +# kaist: +# 123dB -> 600m -> 600 range +# 121dB -> 435m -> 435 range +# 119dB -> 315m -> 315 range +# 117dB -> 220m -> 220 range √ +# 115dB -> 145m -> 145 range + +# san: +# 123dB -> 600m -> 600 range +# 121dB -> 450m -> 450 range +# 119dB -> 330m -> 330 range +# 117dB -> 240m -> 240 range √ +# 115dB -> 165m -> 165 range + +if __name__ == '__main__': + # example usage + data_set_name = 'purdue' + selected_config = get_selected_config(data_set_name) + print(selected_config) diff --git a/zoo/CrowdSim/envs/Crowdsim/env/model/__init__.py b/zoo/CrowdSim/envs/Crowdsim/env/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/CrowdSim/envs/Crowdsim/env/model/agent.py b/zoo/CrowdSim/envs/Crowdsim/env/model/agent.py new file mode 100644 index 000000000..5c1adef36 --- /dev/null +++ b/zoo/CrowdSim/envs/Crowdsim/env/model/agent.py @@ -0,0 +1,194 @@ +import abc +import random +import logging +from zoo.CrowdSim.envs.Crowdsim.env.model.mdp import * + + +class Agent(): + + def __init__(self): + """ + Base class for robot and human. Have the physical attributes of an agent. + + """ + self.policy = None + + def print_info(self): + logging.info('Agent is visible and has "holonomic" kinematic constraint') + + def set_policy(self, policy): + self.policy = policy + + def act(self, state, current_timestep): + if self.policy is None: + raise AttributeError('Policy attribute has to be set!') + action = self.policy.predict(state, current_timestep) + return action + + +class Human(): + """ + Overview: + Human class. Have the physical attributes of a human agent. The human agent has a data queue to store the \ + information blocks. The data queue is updated when the human agent moves and transmits data to the robot. \ + The age of information (aoi) is calculated based on the data queue. + Interface: + `__init__`, `set`, `update`, `get_obs`. 
+ """ + + # collect_v_prob = {1: 1, 2: 0} + def __init__(self, id, config): + self.id = id + self.config = config + self.px = None + self.py = None + self.theta = None + self.aoi = 0 + self.data_queue = InformationQueue() + self.data_amount = 0 + self.collect_v_prob = getattr(self.config, 'collect_v_prob', {1: 1, 2: 0}) + self.collect_v = random.choices(list(map(int, self.collect_v_prob.keys())), list(self.collect_v_prob.values()))[0] + + def set(self, px, py, theta, aoi, data_amount): + """ + Overview: + Set the physical attributes of the human agent. + Arguments: + - px (:obj:`float`): The x-coordinate of the human agent. + - py (:obj:`float`): The y-coordinate of the human agent. + - theta (:obj:`float`): The orientation of the human agent. + - aoi (:obj:`float`): The age of information (aoi) of the human agent. + - data_amount (:obj:`int`): The amount of data blocks in the data queue of the human agent. + """ + self.px = px + self.py = py + self.theta = theta + self.aoi = aoi + self.data_amount = data_amount + + def update(self, px, py, theta, transmitted_data): + """ + Overview: + Update the physical attributes of the human agent and the data queue. The age of information (aoi) is \ + calculated based on the data queue. + Arguments: + - px (:obj:`float`): The x-coordinate of the human agent. + - py (:obj:`float`): The y-coordinate of the human agent. + - theta (:obj:`float`): The orientation of the human agent. + - transmitted_data (:obj:`int`): The number of data blocks transmitted to the robot. + """ + self.px = px # position + self.py = py + self.theta = theta + self.data_queue.update(self.collect_v, transmitted_data) + self.aoi = self.data_queue.total_aoi() + self.data_amount = self.data_queue.total_blocks() + + def get_obs(self): + """ + Overview: + Get the observation of the human agent. The observation includes the position, age of information (aoi), \ + and the amount of data blocks in the data queue. + Returns: + - obs (:obj:`HumanState`): The observation of the human agent. + """ + # obs: (px, py, remaining_data, aoi) + return HumanState( + self.px / self.config.nlon, self.py / self.config.nlat, self.data_amount / self.config.num_timestep, + self.aoi / self.config.num_timestep + ) + + +class Robot(): + """ + Overview: + Robot class. Have the physical attributes of a robot agent. + Interface: + `__init__`, `set`, `get_obs`. + """ + + def __init__(self, id, config): + self.id = id + self.config = config + self.px = None # position + self.py = None + self.theta = None + self.energy = None + + def set(self, px, py, theta, energy): + """ + Overview: + Set the physical attributes of the robot agent. + Arguments: + - px (:obj:`float`): The x-coordinate of the robot agent. + - py (:obj:`float`): The y-coordinate of the robot agent. + - theta (:obj:`float`): The orientation of the robot agent. + - energy (:obj:`float`): The remaining energy of the robot agent. + """ + self.px = px # position + self.py = py + self.theta = theta + self.energy = energy + + def get_obs(self): + """ + Overview: + Get the observation of the robot agent. The observation includes the position, orientation, and the remaining \ + energy of the robot agent. + Returns: + - obs (:obj:`RobotState`): The observation of the robot agent. + """ + return RobotState( + self.px / self.config.nlon, self.py / self.config.nlat, self.theta / self.config.rotation_limit, + self.energy / self.config.max_uav_energy + ) + + +class InformationQueue: + """ + Overview: + Information queue class. 
The data queue is updated when the human agent moves and transmits data to the robot. \ + The age of information (aoi) is calculated based on the data queue. + + Interface: + `__init__`, `update`, `total_aoi`, `total_blocks`. + """ + + def __init__(self): + # Initialize the queue to hold the age of each information block + self.queue = [] + + def update(self, arrivals, departures): + """ + Overview: + Update the data queue. Increase the age of information (aoi) for each block in the queue. Add new information \ + blocks with aoi of 0. Remove the specified number of oldest information blocks. + Arguments: + - arrivals (:obj:`int`): The number of new information blocks entering the queue. + - departures (:obj:`int`): The number of oldest information blocks leaving the queue. + """ + # Increase the age of information (aoi) for each block in the queue + self.queue = [age + 1 for age in self.queue] + + # Add new information blocks with aoi of 0 + self.queue.extend([0] * arrivals) + + # Remove the specified number of oldest information blocks + self.queue = self.queue[departures:] if departures <= len(self.queue) else [] + + def total_aoi(self): + # Return the total age of information in the queue + return sum(self.queue) + + def total_blocks(self): + # Return the total number of information blocks in the queue + return len(self.queue) + + +# # Example of using the InformationQueue class +# info_queue = InformationQueue() +# info_queue.update(arrivals=5, departures=0) # 5 blocks enter the queue, all with aoi of 0 +# info_queue.update(arrivals=3, departures=2) # 3 new blocks enter, 2 blocks leave +# total_age = info_queue.total_aoi() # Calculate the total age of information in the queue + +# total_age diff --git a/zoo/CrowdSim/envs/Crowdsim/env/model/mdp.py b/zoo/CrowdSim/envs/Crowdsim/env/model/mdp.py new file mode 100644 index 000000000..9d4d2b770 --- /dev/null +++ b/zoo/CrowdSim/envs/Crowdsim/env/model/mdp.py @@ -0,0 +1,94 @@ +from collections import namedtuple +from itertools import product +import torch +import numpy as np + + +# State +class HumanState(object): + + def __init__(self, px, py, theta, aoi): + self.px = px + self.py = py + self.theta = theta + self.aoi = aoi + self.position = (self.px, self.py) + + def __add__(self, other): + return other + (self.px, self.py, self.theta, self.aoi) + + def __str__(self): + return ' '.join([str(x) for x in [self.px, self.py, self.theta, self.aoi]]) + + def to_tuple(self): + return self.px, self.py, self.theta, self.aoi + + +class RobotState(object): + + def __init__(self, px, py, theta, energy): + self.px = px + self.py = py + self.theta = theta + self.energy = energy + + self.position = (self.px, self.py) + + def __add__(self, other): + return other + (self.px, self.py, self.theta, self.energy) + + def __str__(self): + return ' '.join([str(x) for x in [self.px, self.py, self.theta, self.energy]]) + + def to_tuple(self): + return self.px, self.py, self.theta, self.energy + + +class JointState(object): + + def __init__(self, robot_states, human_states): + for robot_state in robot_states: + assert isinstance(robot_state, RobotState) + for human_state in human_states: + assert isinstance(human_state, HumanState) + + self.robot_states = robot_states + self.human_states = human_states + + def to_tensor(self, add_batch_size=False, device=None): + robot_states_tensor = torch.tensor( + [robot_state.to_tuple() for robot_state in self.robot_states], dtype=torch.float32 + ) + human_states_tensor = torch.tensor( + [human_state.to_tuple() for human_state in 
self.human_states], dtype=torch.float32 + ) + + if add_batch_size: # True + robot_states_tensor = robot_states_tensor.unsqueeze(0) + human_states_tensor = human_states_tensor.unsqueeze(0) + + if device is not None: + robot_states_tensor = robot_states_tensor.to(device) + human_states_tensor = human_states_tensor.to(device) + + return robot_states_tensor, human_states_tensor + + def to_array(self): + robot_states_array = np.array([robot_state.to_tuple() for robot_state in self.robot_states]) + human_states_array = np.array([human_state.to_tuple() for human_state in self.human_states]) + + return robot_states_array, human_states_array + + +def build_action_space(config): + robot_num = config.robot_num + + # dx, dy + one_uav_action_space = config.one_uav_action_space + action_space = list(product(one_uav_action_space, repeat=robot_num)) + + return np.array(action_space) + + +if __name__ == "__main__": + print(build_action_space()) diff --git a/zoo/CrowdSim/envs/Crowdsim/env/model/utils.py b/zoo/CrowdSim/envs/Crowdsim/env/model/utils.py new file mode 100644 index 000000000..421b9508a --- /dev/null +++ b/zoo/CrowdSim/envs/Crowdsim/env/model/utils.py @@ -0,0 +1,349 @@ +import numpy as np + +np.seterr(invalid='ignore') + +from zoo.CrowdSim.envs.Crowdsim.env.model.agent import * +from zoo.CrowdSim.envs.Crowdsim.env.model.mdp import JointState +from shapely.geometry import * + + +def tensor_to_joint_state(state, config): + """ + Overview: + Convert the state tensor to the JointState object. The state tensor is a tuple of two tensors, the first one \ + is the robot state tensor, and the second one is the human state tensor. The robot state tensor is a tensor of \ + shape (1, robot_num, 4), and the human state tensor is a tensor of shape (1, human_num, 4). + Arguments: + - state (:obj:`tuple`): The state tensor. + - config (:obj:`dict`): The configuration of the environment. + Returns: + - joint_state (:obj:`JointState`): The JointState object. + """ + robot_states, human_states = state + + robot_states = robot_states.cpu().squeeze(0).data.numpy() + robot_states = [ + RobotState( + robot_state[0] * config.nlon, robot_state[1] * config.nlat, robot_state[2] * config.rotation_limit, + robot_state[3] * config.max_uav_energy + ) for robot_state in robot_states + ] + + human_states = human_states.cpu().squeeze(0).data.numpy() + human_states = [ + HumanState( + human_state[0] * config.nlon, human_state[1] * config.nlat, human_state[2] * config.rotation_limit, + human_state[3] * config.num_timestep + ) for human_state in human_states + ] + + return JointState(robot_states, human_states) + + +def tensor_to_robot_states(robot_state_tensor, config): + """ + Overview: + Convert the robot state tensor to a list of RobotState objects. The robot state tensor is a tensor of shape \ + (1, robot_num, 4). + Arguments: + - robot_state_tensor (:obj:`torch.Tensor`): The robot state tensor. + - config (:obj:`dict`): The configuration of the environment. + Returns: + - robot_states (:obj:`list`): The list of RobotState objects. + """ + robot_states = robot_state_tensor.cpu().squeeze(0).data.numpy() + robot_states = [ + RobotState( + robot_state[0] * config.nlon, robot_state[1] * config.nlat, robot_state[2] * config.rotation_limit, + robot_state[3] * config.max_uav_energy + ) for robot_state in robot_states + ] + return robot_states + + +def get_human_position_list(selected_timestep, human_df, config): + """ + Overview: + Get the human position list at the selected timestep. 
The human position list is a list of tuples, each tuple \ + contains the x, y, and theta of a human. + Arguments: + - selected_timestep (:obj:`int`): The selected timestep. + - human_df (:obj:`pandas.DataFrame`): The human dataframe. + - config (:obj:`dict`): The configuration of the environment. + Returns: + - human_position_list (:obj:`list`): The human position list. + """ + # config.step_time means the time interval between two timesteps + selected_timestamp = config.start_timestamp + selected_timestep * config.step_time + selected_data = human_df[human_df.timestamp == selected_timestamp] + selected_data = selected_data.set_index("id") + + if selected_timestep < config.num_timestep: + selected_next_data = human_df[human_df.timestamp == (selected_timestamp + config.step_time)] + selected_next_data = selected_next_data.set_index("id") + else: + selected_next_data = None + + return selected_data, selected_next_data + + +def get_human_position_from_list(selected_timestep, human_id, selected_data, selected_next_data, config): + """ + Overview: + Get the human position from the human position list at the selected timestep. The human position is a tuple \ + containing the x, y, and theta of the human. + Arguments: + - selected_timestep (:obj:`int`): The selected timestep. + - human_id (:obj:`int`): The human id. + - selected_data (:obj:`pandas.DataFrame`): The human position list at the selected timestep. + - selected_next_data (:obj:`pandas.DataFrame`): The human position list at the next timestep. + - config (:obj:`dict`): The configuration of the environment. + Returns: + - px (:obj:`float`): The x coordinate of the human. + - py (:obj:`float`): The y coordinate of the human. + - theta (:obj:`float`): The orientation of the human. + """ + px, py = selected_data.loc[human_id, ["x", "y"]] + + if selected_timestep < config.num_timestep: + npx, npy = selected_next_data.loc[human_id, ["x", "y"]] + theta = get_theta(0, 0, npx - px, npy - py) + # print(px, py, npx, npy, theta) + else: + theta = 0 + + return px, py, theta + + +def judge_aoi_update(human_position, robot_position, config): + """ + Overview: + Judge whether the AoI should be updated, i.e., the human is in the sensing range of the robot. + Args: + - human_position (:obj:`list`): The position of the human. + - robot_position (:obj:`list`): The position of the robot. + - config (:obj:`dict`): The configuration of the environment. + Returns: + - should_update (:obj:`bool`): Whether the AoI should be updated. + """ + should_reset = False + for robot_id in range(config.robot_num): + unit_distance = np.sqrt( + np.power(robot_position[robot_id][0] - human_position[0], 2) + + np.power(robot_position[robot_id][1] - human_position[1], 2) + ) + if unit_distance <= config.sensing_range: + should_reset = True + break + + return should_reset + + +def inPoly(polygon, x, y): + """ + Overview: + Judge whether a point is in a polygon. + Arguments: + - polygon (:obj:`list`): The polygon. + - x (:obj:`float`): The x coordinate of the point. + - y (:obj:`float`): The y coordinate of the point. + Returns: + - in_poly (:obj:`bool`): Whether the point is in the polygon. + """ + pt = (x, y) + line = LineString(polygon) + point = Point(pt) + polygon = Polygon(line) + return polygon.contains(point) + + +def iscrosses(line1, line2): + """ + Overview: + Judge whether two lines cross each other. + Arguments: + - line1 (:obj:`list`): The first line. + - line2 (:obj:`list`): The second line. 
+ Returns: + - crosses (:obj:`bool`): Whether the two lines cross each other. + """ + if LineString(line1).crosses(LineString(line2)): + return True + return False + + +def crossPoly(square, x1, y1, x2, y2): + """ + Overview: + Judge whether a line crosses a polygon. + Arguments: + - square (:obj:`list`): The polygon. + - x1 (:obj:`float`): The x coordinate of the start point of the line. + - y1 (:obj:`float`): The y coordinate of the start point of the line. + - x2 (:obj:`float`): The x coordinate of the end point of the line. + - y2 (:obj:`float`): The y coordinate of the end point of the line. + Returns: + - crosses (:obj:`bool`): Whether the line crosses the polygon. + """ + our_line = LineString([[x1, y1], [x2, y2]]) + line1 = LineString([square[0], square[2]]) + line2 = LineString([square[1], square[3]]) + if our_line.crosses(line1) or our_line.crosses(line2): + return True + else: + return False + + +def judge_collision(new_robot_px, new_robot_py, old_robot_px, old_robot_py, config): + """ + Overview: + Judge whether a collision happens. A collision happens when the new position of the robot is in the no-fly zone. + Arguments: + - new_robot_px (:obj:`float`): The x coordinate of the new position of the robot. + - new_robot_py (:obj:`float`): The y coordinate of the new position of the robot. + - old_robot_px (:obj:`float`): The x coordinate of the old position of the robot. + - old_robot_py (:obj:`float`): The y coordinate of the old position of the robot. + - config (:obj:`dict`): The configuration of the environment. + Returns: + - collision (:obj:`bool`): Whether a collision happens. + """ + if config.no_fly_zone is None: + return False + + for square in config.no_fly_zone: + if inPoly(square, new_robot_px, new_robot_py): + return True + if crossPoly(square, new_robot_px, new_robot_py, old_robot_px, old_robot_py): + return True + return False + + +def get_theta(x1, y1, x2, y2): + ang1 = np.arctan2(y1, x1) + ang2 = np.arctan2(y2, x2) + theta = np.rad2deg((ang1 - ang2) % (2 * np.pi)) + return theta + + +def consume_uav_energy(fly_time, hover_time, config): + """ + Overview: + Calculate the energy consumption of the UAV. The energy consumption is calculated based on the power consumption \ + of the UAV in the flying state and the hovering state. + Arguments: + - fly_time (:obj:`float`): The flying time. + - hover_time (:obj:`float`): The hovering time. + - config (:obj:`dict`): The configuration of the environment. + Returns: + - energy (:obj:`float`): The energy consumption of the UAV. + """ + # configs + Pu = 0.5 # the average transmitted power of each user, W, e.g. 
mobile phone + P0 = 79.8563 # blade profile power, W + P1 = 88.6279 # derived power, W + U_tips = 120 # tip speed of the rotor blade of the UAV,m/s + v0 = 4.03 # the mean rotor induced velocity in the hovering state,m/s + d0 = 0.6 # fuselage drag ratio + rho = 1.225 # density of air,kg/m^3 + s0 = 0.05 # the rotor solidity + A = 0.503 # the area of the rotor disk, m^2 + Vt = config.velocity # velocity of the UAV,m/s + + Power_flying = P0 * (1 + 3 * Vt ** 2 / U_tips ** 2) + \ + P1 * np.sqrt((np.sqrt(1 + Vt ** 4 / (4 * v0 ** 4)) - Vt ** 2 / (2 * v0 ** 2))) + \ + 0.5 * d0 * rho * s0 * A * Vt ** 3 + + Power_hovering = P0 + P1 + + return fly_time * Power_flying + hover_time * Power_hovering + + +def get_border(ur, lf): + upper_left = [lf[0], ur[1]] + upper_right = [ur[0], ur[1]] + lower_right = [ur[0], lf[1]] + lower_left = [lf[0], lf[1]] + + coordinates = [upper_left, upper_right, lower_right, lower_left, upper_left] + + geo_json = { + "type": "FeatureCollection", + "properties": { + "lower_left": lower_left, + "upper_right": upper_right + }, + "features": [] + } + + grid_feature = { + "type": "Feature", + "geometry": { + "type": "Polygon", + "coordinates": [coordinates], + } + } + + geo_json["features"].append(grid_feature) + + return geo_json + + +def traj_to_timestamped_geojson(index, trajectory, robot_num, color): + point_gdf = trajectory.df.copy() + point_gdf["previous_geometry"] = point_gdf["geometry"].shift() + point_gdf["time"] = point_gdf.index + point_gdf["previous_time"] = point_gdf["time"].shift() + + features = [] + + # for Point in GeoJSON type + for _, row in point_gdf.iterrows(): + corrent_point_coordinates = [row["geometry"].xy[0][0], row["geometry"].xy[1][0]] + current_time = [row["time"].isoformat()] + + if index < robot_num: + radius = 8 # 125(5 units) + opacity = 0.05 + popup_html = f'

<div>UAV {int(row["id"])}</div>' + f'<div>raw coord: {corrent_point_coordinates}</div>' \ + + f'<div>grid coord: ({row["x"]},{row["y"]})</div>' \ + + f'<div>dist coord: ({row["x_distance"]}m, {row["y_distance"]}m)</div>' \ + + f'<div>energy: {row["energy"]}J</div>'
+ else: + radius = 2 + opacity = 1 + popup_html = f'<div>Human {int(row["id"])}</div>' + f'<div>raw coord: {corrent_point_coordinates}</div>' \ + + f'<div>grid coord: ({row["x"]},{row["y"]})</div>' \ + + f'<div>dist coord: ({row["x_distance"]}m, {row["y_distance"]}m)</div>' \ + + f'<div>aoi: {int(row["aoi"])}</div>
' + + # for Point in GeoJSON type (Temporally Deprecated) + features.append( + { + "type": "Feature", + "geometry": { + "type": "Point", + "coordinates": corrent_point_coordinates, + }, + "properties": { + "times": current_time, + 'popup': popup_html, + "icon": 'circle', # point + "iconstyle": { + 'fillColor': color, + 'fillOpacity': opacity, # 透明度 + 'stroke': 'true', + 'radius': radius, + 'weight': 1, + }, + "style": { # line + "color": color, + }, + "code": 11, + }, + } + ) + return features + + +if __name__ == "__main__": + print(judge_collision(new_robot_px=6505, new_robot_py=5130, old_robot_px=6925, old_robot_py=5130)) diff --git a/zoo/CrowdSim/envs/__init__.py b/zoo/CrowdSim/envs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/CrowdSim/envs/crowdsim_lightzero_env.py b/zoo/CrowdSim/envs/crowdsim_lightzero_env.py new file mode 100644 index 000000000..bf20aa9c6 --- /dev/null +++ b/zoo/CrowdSim/envs/crowdsim_lightzero_env.py @@ -0,0 +1,155 @@ +from typing import Union, Optional + +import gym +import numpy as np +from itertools import product +import logging + +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.envs import ObsPlusPrevActRewWrapper +from ding.torch_utils import to_ndarray +from ding.utils import ENV_REGISTRY + +import zoo.CrowdSim.envs.Crowdsim.env + + +@ENV_REGISTRY.register('crowdsim_lightzero') +class CrowdSimEnv(BaseEnv): + + def __init__(self, cfg: dict = {}) -> None: + self._cfg = cfg + self._init_flag = False + self._replay_path = cfg.get('replay_path', None) + self._robot_num = self._cfg.robot_num + self._human_num = self._cfg.human_num + self._observation_space = gym.spaces.Dict( + { + 'robot_state': gym.spaces.Box( + low=float("-inf"), high=float("inf"), shape=(self._robot_num, 4), dtype=np.float32 + ), + 'human_state': gym.spaces.Box( + low=float("-inf"), high=float("inf"), shape=(self._human_num, 4), dtype=np.float32 + ) + } + ) + # action space + # one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] + self.real_action_space = list(product(self._cfg.one_uav_action_space, repeat=self._robot_num)) + one_uav_action_n = len(self._cfg.one_uav_action_space) + self._action_space = gym.spaces.Discrete(one_uav_action_n ** self._robot_num) + self._action_space.seed(0) # default seed + self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32) + self._continuous = False + # obs_mode 'dict': {'robot_state': robot_state, 'human_state': human_state} + # obs_mode '2-dim-array': np.concatenate((robot_state, human_state), axis=0) + # obs_mode '1-dim-array': np.concatenate((robot_state, human_state), axis=0).flatten() + self.obs_mode = self._cfg.get('obs_mode', '2-dim-array') + assert self.obs_mode in [ + 'dict', '2-dim-array', '1-dim-array' + ], "obs_mode should be 'dict' or '2-dim-array' or '1-dim-array'!" + # action_mode 'combine': combine all robot actions into one action, action space size = one_uav_action_n**robot_num + # action_mode 'separate': separate robot actions, shape (robot_num,), for each robot action space size = one_uav_action_n + self.action_mode = self._cfg.get('action_mode', 'combine') + assert self.action_mode in ['combine', 'separate'], "action_mode should be 'combine' or 'separate'!" 
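+        # Example of the resulting spaces (assuming the purdue defaults from crowd_sim_base_config.py:
+        # robot_num=2, human_num=59 and a 9-move one_uav_action_space):
+        #   obs_mode '2-dim-array'  -> observation of shape (2 + 59, 4) = (61, 4)
+        #   obs_mode '1-dim-array'  -> observation of shape ((2 + 59) * 4,) = (244,)
+        #   action_mode 'combine'   -> Discrete(9 ** 2) = Discrete(81) joint UAV actions
+        #   action_mode 'separate'  -> one index in [0, 9) per UAV, i.e. an action of shape (2,)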
+ + def reset(self) -> np.ndarray: + if not self._init_flag: + self._env = gym.make('CrowdSim-v0', dataset=self._cfg.dataset, custom_config=self._cfg) + self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + np_seed = 100 * np.random.randint(1, 1000) + self._env.seed(self._seed + np_seed) + self._action_space.seed(self._seed + np_seed) + elif hasattr(self, '_seed'): + self._env.seed(self._seed) + self._action_space.seed(self._seed) + self._eval_episode_return = 0 + # process obs + raw_obs = self._env.reset() + obs_list = list(raw_obs.to_tensor()) + if self.obs_mode == 'dict': + obs = {'robot_state': obs_list[0], 'human_state': obs_list[1]} + elif self.obs_mode == '2-dim-array': + # robot_state: (robot_num, 4), human_state: (human_num, 4) + obs = np.concatenate((obs_list[0], obs_list[1]), axis=0) + elif self.obs_mode == '1-dim-array': + obs = np.concatenate((obs_list[0], obs_list[1]), axis=0).flatten() + action_mask = np.ones(self.action_space.n, 'int8') + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} + if self._replay_path is not None: + self._frame = [] + + return obs + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: + if self.action_mode == 'combine': + if isinstance(action, np.ndarray) and action.shape == (1, ): + action = action.squeeze() + real_action = self.real_action_space[action] + elif self.action_mode == 'separate': + assert isinstance(action, np.ndarray) and action.shape == (self._robot_num, ), "illegal action!" + real_action = tuple([self._cfg.one_uav_action_space[action[i]] for i in range(self._robot_num)]) + assert isinstance(real_action, tuple) and len(real_action) == self._robot_num, "illegal action!" 
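+        # For instance (assuming action_mode 'combine', robot_num=2 and the purdue 9-move space),
+        # the flat index k selects real_action_space[k] == (one_uav_action_space[k // 9], one_uav_action_space[k % 9]),
+        # so k = 10 decodes to ([30, 0], [30, 0]), i.e. both UAVs move +30 on the x axis.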
+ raw_obs, rew, done, info = self._env.step(real_action) + obs_list = list(raw_obs.to_array()) + if self.obs_mode == 'dict': + obs = {'robot_state': obs_list[0], 'human_state': obs_list[1]} + elif self.obs_mode == '2-dim-array': + # robot_state: (robot_num, 4), human_state: (human_num, 4) + obs = np.concatenate((obs_list[0], obs_list[1]), axis=0) + elif self.obs_mode == '1-dim-array': + obs = np.concatenate((obs_list[0], obs_list[1]), axis=0).flatten() + + self._eval_episode_return += rew + if done: + info['eval_episode_return'] = self._eval_episode_return + # logging.INFO('one game finish!') + + action_mask = np.ones(self.action_space.n, 'int8') + obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} + rew = to_ndarray([rew]).astype(np.float32) + if self._replay_path is not None: + self._frame.append(self._env.render()) + if done: + import imageio, os + if not os.path.exists(self._replay_path): + os.makedirs(self._replay_path) + imageio.mimsave(self._replay_path + '/replay.gif', self._frame) + # save env.human_df as csv + self._env.human_df.to_csv(self._replay_path + '/human_df.csv') + return BaseEnvTimestep(obs, rew, done, info) + + def enable_save_replay(self, replay_path: Optional[str] = None) -> None: + if replay_path is None: + replay_path = './video' + self._replay_path = replay_path + + def random_action(self) -> np.ndarray: + random_action = self.action_space.sample() + random_action = to_ndarray([random_action], dtype=np.int64) + return random_action + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_space + + def __repr__(self) -> str: + return "LightZero CrowdSim Env" diff --git a/zoo/CrowdSim/envs/test_CrowdSim_env.py b/zoo/CrowdSim/envs/test_CrowdSim_env.py new file mode 100644 index 000000000..81ac5f7ab --- /dev/null +++ b/zoo/CrowdSim/envs/test_CrowdSim_env.py @@ -0,0 +1,32 @@ +import numpy as np +from easydict import EasyDict +from zoo.CrowdSim.envs.CrowdSim_env import CrowdSimEnv + +mcfg = EasyDict( + env_name='CrowdSim-v0', + dataset='purdue', + robot_num=2, + human_num=59, # purdue + one_uav_action_space=[[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]] +) + + +def test_naive(cfg): + env = CrowdSimEnv(cfg) + env.seed(314) + assert env._seed == 314 + obs = env.reset() + assert obs['observation'].shape == (244, ) + for i in range(10): + random_action = env.random_action() + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs['observation'], np.ndarray) + assert isinstance(timestep.done, bool) + assert timestep.obs['observation'].shape == (244, ) + assert timestep.reward.shape == (1, ) + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + +test_naive(mcfg) diff --git a/zoo/CrowdSim/envs/test_crowdsim_lightzero_env.py b/zoo/CrowdSim/envs/test_crowdsim_lightzero_env.py new file mode 100644 index 000000000..46eb15b35 --- /dev/null +++ b/zoo/CrowdSim/envs/test_crowdsim_lightzero_env.py @@ -0,0 +1,86 @@ +import numpy as np +import pytest +from easydict import EasyDict +from zoo.CrowdSim.envs.crowdsim_lightzero_env import CrowdSimEnv + +mcfg=EasyDict( + env_name='CrowdSim-v0', + dataset = 'purdue', + robot_num = 2, + human_num = 59, # purdue + one_uav_action_space = [[0, 0], [30, 0], [-30, 0], [0, 30], [0, -30]], + obs_mode = '2-dim-array', + env_mode = 'hard', + transmit_v=120, + 
collect_v_prob = {'1': 1, '2': 0}, + ) + +@ pytest.mark.envtest + +class TestCrowdSimEnv: + + def test_obs_dict(self): + mcfg['obs_mode'] = 'dict' + env = CrowdSimEnv(mcfg) + env.seed(314) + assert env._seed == 314 + obs = env.reset() + assert isinstance(obs['observation'], dict) + assert obs['observation']['robot_state'].shape == (2, 4) + assert obs['observation']['human_state'].shape == (59, 4) + for i in range(10): + random_action = env.random_action() + timestep = env.step(random_action) + print(timestep) + assert isinstance(timestep.obs['observation'], dict) + assert timestep.obs['observation']['robot_state'].shape == (2, 4) + assert timestep.obs['observation']['human_state'].shape == (59, 4) + assert isinstance(timestep.done, bool) + assert timestep.reward.shape == (1, ) + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def test_obs_2_dim_array(self): + mcfg['obs_mode'] = '2-dim-array' + env = CrowdSimEnv(mcfg) + env.seed(314) + assert env._seed == 314 + obs = env.reset() + assert obs['observation'].shape == (61, 4) + for i in range(10): + random_action = env.random_action() + timestep = env.step(random_action) + print(timestep) + assert timestep.obs['observation'].shape == (61, 4) + assert isinstance(timestep.done, bool) + assert timestep.reward.shape == (1, ) + print(env.observation_space, env.action_space, env.reward_space) + env.close() + + def test_obs_1_dim_array(self): + mcfg['obs_mode'] = '1-dim-array' + env = CrowdSimEnv(mcfg) + env.seed(314) + env.enable_save_replay('/home/nighoodRen/LightZero/result/test_replay') + assert env._seed == 314 + obs = env.reset() + assert obs['observation'].shape == (244, ) + while True: + random_action = env.random_action() + timestep = env.step(random_action) + print(timestep) + assert timestep.obs['observation'].shape == (244, ) + assert isinstance(timestep.done, bool) + assert timestep.reward.shape == (1, ) + if timestep.done: + break + print(env.observation_space, env.action_space, env.reward_space) + print('episode reward:', timestep.info['eval_episode_return']) + env.close() + + +if __name__ == '__main__': + test = TestCrowdSimEnv() + # test.test_obs_dict() + # test.test_obs_2_dim_array() + test.test_obs_1_dim_array()
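+
+# # Minimal end-to-end sketch combining the dataset config helper with the LightZero env.
+# # Hypothetical wiring, not exercised by the tests above; it assumes the purdue CSV path in
+# # crowd_sim_base_config.py exists locally and that importing crowdsim_lightzero_env registers 'CrowdSim-v0'.
+# from easydict import EasyDict
+# from zoo.CrowdSim.envs.Crowdsim.env.crowd_sim_base_config import get_selected_config
+# from zoo.CrowdSim.envs.crowdsim_lightzero_env import CrowdSimEnv
+#
+# cfg = get_selected_config('purdue')
+# cfg.update(EasyDict(dataset='purdue', obs_mode='1-dim-array', action_mode='combine'))
+# env = CrowdSimEnv(cfg)
+# obs = env.reset()  # obs['observation'].shape == ((cfg.robot_num + cfg.human_num) * 4,) == (244,)
+# timestep = env.step(env.random_action())
+# env.close()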