diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1e772ecc..fa23c283 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,12 @@ repos: - repo: https://github.com/asottile/seed-isort-config - rev: v1.9.4 + rev: v2.2.0 hooks: - id: seed-isort-config args: [--exclude=^((examples|docs)/.*)$] - repo: https://github.com/timothycrosley/isort + rev: 5.4.2 hooks: - id: isort @@ -14,7 +15,7 @@ repos: rev: 20.8b1 hooks: - id: black - language_version: python3.7 + language_version: python3 - repo: https://gitlab.com/pycqa/flake8 rev: 3.8.3 diff --git a/genrl/agents/deep/a2c/a2c.py b/genrl/agents/deep/a2c/a2c.py index 633b005c..7dacdaf7 100644 --- a/genrl/agents/deep/a2c/a2c.py +++ b/genrl/agents/deep/a2c/a2c.py @@ -70,10 +70,12 @@ def _create_model(self) -> None: state_dim, action_dim, discrete, action_lim = get_env_properties( self.env, self.network ) + if isinstance(self.network, str): arch_type = self.network if self.shared_layers is not None: arch_type += "s" + self.ac = get_model("ac", arch_type)( state_dim, action_dim, diff --git a/genrl/agents/deep/ddpg/ddpg.py b/genrl/agents/deep/ddpg/ddpg.py index cf73aed1..a34eea69 100644 --- a/genrl/agents/deep/ddpg/ddpg.py +++ b/genrl/agents/deep/ddpg/ddpg.py @@ -66,6 +66,7 @@ def _create_model(self) -> None: arch_type = self.network if self.shared_layers is not None: arch_type += "s" + self.ac = get_model("ac", arch_type)( state_dim, action_dim, @@ -75,6 +76,18 @@ def _create_model(self) -> None: "Qsa", False, ).to(self.device) + elif isinstance(self.network, str) and self.shared_layers is not None: + arch_type = self.network + "s" + self.ac = get_model("ac", arch_type)( + state_dim, + action_dim, + critic_prev=self.critic_prev, + actor_prev=self.actor_prev, + shared_layers=self.shared_layers, + critic_post=self.value_layers, + actor_post=self.policy_layers, + val_type="Qsa", + ).to(self.device) else: self.ac = self.network diff --git a/genrl/agents/deep/ppo1/ppo1.py b/genrl/agents/deep/ppo1/ppo1.py index 97b52306..daef962d 100644 --- a/genrl/agents/deep/ppo1/ppo1.py +++ b/genrl/agents/deep/ppo1/ppo1.py @@ -85,6 +85,17 @@ def _create_model(self): action_lim=action_lim, activation=self.activation, ).to(self.device) + elif isinstance(self.network, str) and self.shared_layers is not None: + arch_type = self.network + "s" + self.ac = get_model("ac", arch_type)( + state_dim, + action_dim, + critic_prev=self.critic_prev, + actor_prev=self.actor_prev, + shared_layers=self.shared_layers, + critic_post=self.value_layers, + actor_post=self.policy_layers, + ).to(self.device) else: self.ac = self.network.to(self.device) diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py index 0a5b6e7c..3ba9e7e1 100644 --- a/genrl/core/buffers.py +++ b/genrl/core/buffers.py @@ -117,15 +117,15 @@ def sample( ] ): """ - (Returns randomly sampled memories from replay memory along with their + (Returns randomly sampled memories from replay memory along with their respective indices and weights) - :param batch_size: Number of samples per batch - :param beta: (Bias exponent used to correct + :param batch_size: Number of samples per batch + :param beta: (Bias exponent used to correct Importance Sampling (IS) weights) - :type batch_size: int - :type beta: float - :returns: (Tuple containing `states`, `actions`, `next_states`, + :type batch_size: int + :type beta: float + :returns: (Tuple containing `states`, `actions`, `next_states`, `rewards`, `dones`, `indices` and `weights`) """ if beta is None: @@ -181,3 +181,118 @@ def __len__(self) -> 
int: @property def pos(self): return len(self.buffer) + + +class MultiAgentReplayBuffer: + """ + Implements the basic Experience Replay Mechanism for MultiAgents + by feeding in global states, global actions, global rewards, + global next_states, global dones + + :param capacity: Size of the replay buffer + :type capacity: int + :param num_agents: Number of agents in the environment + :type num_agents: int + """ + + def __init__(self, num_agents: int, capacity: int): + """ + Initialising the buffer + :param num_agents: number of agents in the environment + :type num_agents: int + :param capacity: Max buffer size + :type capacity: int + + """ + self.capacity = capacity + self.num_agents = num_agents + self.buffer = deque(maxlen=self.capacity) + + def push(self, inp: Tuple) -> None: + """ + Adds new experience to buffer + + :param inp: (Tuple containing `state`, `action`, `reward`, + `next_state` and `done`) + :type inp: tuple + :returns: None + """ + self.buffer.append(inp) + + def sample(self, batch_size): + + """ + Returns randomly sampled experiences from replay memory + + :param batch_size: Number of samples per batch + :type batch_size: int + :returns: (Tuple composing of `indiv_obs_batch`, + `indiv_action_batch`, `indiv_reward_batch`, `indiv_next_obs_batch`, + `global_state_batch`, `global_actions_batch`, `global_next_state_batch`, + `done_batch`) + """ + indiv_obs_batch = [ + [] for _ in range(self.num_agents) + ] # [ [states of agent 1], ... ,[states of agent n] ] ] + indiv_action_batch = [ + [] for _ in range(self.num_agents) + ] # [ [actions of agent 1], ... , [actions of agent n]] + indiv_reward_batch = [[] for _ in range(self.num_agents)] + indiv_next_obs_batch = [[] for _ in range(self.num_agents)] + + global_state_batch = [] + global_next_state_batch = [] + global_actions_batch = [] + done_batch = [] + + batch = random.sample(self.buffer, batch_size) + + for experience in batch: + state, action, reward, next_state, done = experience + + for i in range(self.num_agents): + indiv_obs_batch[i].append(state[i]) + indiv_action_batch[i].append(action[i]) + indiv_reward_batch[i].append(reward[i]) + indiv_next_obs_batch[i].append(next_state[i]) + + global_state_batch.append(torch.cat(state)) + global_actions_batch.append(torch.cat(action)) + global_next_state_batch.append(torch.cat(next_state)) + done_batch.append(done) + + global_state_batch = torch.stack(global_state_batch) + global_actions_batch = torch.stack(global_actions_batch) + global_next_state_batch = torch.stack(global_next_state_batch) + done_batch = torch.stack(done_batch) + indiv_obs_batch = torch.stack( + [torch.FloatTensor(obs) for obs in indiv_obs_batch] + ) + indiv_action_batch = torch.stack( + [torch.FloatTensor(act) for act in indiv_action_batch] + ) + indiv_reward_batch = torch.stack( + [torch.FloatTensor(rew) for rew in indiv_reward_batch] + ) + indiv_next_obs_batch = torch.stack( + [torch.FloatTensor(next_obs) for next_obs in indiv_next_obs_batch] + ) + + return ( + indiv_obs_batch, + indiv_action_batch, + indiv_reward_batch, + indiv_next_obs_batch, + global_state_batch, + global_actions_batch, + global_next_state_batch, + done_batch, + ) + + def __len__(self): + """ + Gives number of experiences in buffer currently + + :returns: Length of replay memory + """ + return len(self.buffer) diff --git a/genrl/core/rollout_storage.py b/genrl/core/rollout_storage.py index 48679be4..6fcad2c6 100644 --- a/genrl/core/rollout_storage.py +++ b/genrl/core/rollout_storage.py @@ -257,3 +257,202 @@ def _get_samples(self, 
batch_inds: np.ndarray) -> RolloutBufferSamples: self.returns[batch_inds].flatten(), ) return RolloutBufferSamples(*tuple(map(self.to_torch, data))) + + +class MultiAgentRolloutBuffer(BaseBuffer): + """ + Rollout buffer used in on-policy algorithms like MAA2C/MAA3C. + :param num_agents: (int) Max number of agents in the environment + :param buffer_size: (int) Max number of element in the buffer + :param env: (Environment) The environment being trained on + :param device: (torch.device) + :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: (float) Discount factor + :param n_envs: (int) Number of parallel environments + """ + + def __init__( + self, + num_agents: int, + buffer_size: int, + env, + device: Union[torch.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + ): + super(MultiAgentRolloutBuffer, self).__init__(buffer_size, env, device) + + self.buffer_size = buffer_size + self.num_agents = num_agents + self.env = env + self.device = device + self.gae_lambda = gae_lambda + self.gamma = gamma + + self.observations, self.actions, self.rewards, self.advantages = ( + [None] * self.num_agents, + [None] * self.num_agents, + [None] * self.num_agents, + [None] * self.num_agents, + ) + self.returns, self.dones, self.values, self.log_probs = ( + [None] * self.num_agents, + [None] * self.num_agents, + [None] * self.num_agents, + [None] * self.num_agents, + ) + self.generator_ready = False + self.reset() + + def reset(self) -> None: + self.observations = torch.zeros( + *(self.buffer_size, self.env.n_envs, self.num_agents, *self.env.obs_shape) + ) + self.actions = torch.zeros( + *( + self.buffer_size, + self.env.n_envs, + self.num_agents, + *self.env.action_shape, + ) + ) + self.rewards = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents) + self.returns = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents) + self.dones = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents) + self.values = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents) + self.log_probs = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents) + self.advantages = torch.zeros( + self.buffer_size, self.env.n_envs, self.num_agents + ) + self.generator_ready = False + super(MultiAgentRolloutBuffer, self).reset() + + def add( + self, + obs: torch.zeros, + action: torch.zeros, + reward: torch.zeros, + done: torch.zeros, + value: torch.Tensor, + log_prob: torch.Tensor, + ) -> None: + """ + :param obs: (torch.zeros) Observation + :param action: (torch.zeros) Action + :param reward: (torch.zeros) + :param done: (torch.zeros) End of episode signal. + :param value: (torch.Tensor) estimated value of the current state + following the current policy. + :param log_prob: (torch.Tensor) log probability of the action + following the current policy. 
+ """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + self.observations[self.pos] = obs.detach().clone() + self.actions[self.pos] = action.squeeze().detach().clone() + self.rewards[self.pos] = reward.detach().clone() + self.dones[self.pos] = done.detach().clone() + self.values[self.pos] = ( + value.detach().clone().flatten().reshape(-1, self.num_agents) + ) + self.log_probs[self.pos] = ( + log_prob.detach().clone().flatten().reshape(-1, self.num_agents) + ) + self.pos += 1 + + if self.pos == self.buffer_size: + self.full = True + + def get( + self, batch_size: Optional[int] = None + ) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.env.n_envs) + # Prepare the data + if not self.generator_ready: + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.env.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.env.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray) -> RolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten().reshape(-1, self.num_agents), + self.log_probs[batch_inds].flatten().reshape(-1, self.num_agents), + self.advantages[batch_inds].flatten().reshape(-1, self.num_agents), + self.returns[batch_inds].flatten().reshape(-1, self.num_agents), + ) + return RolloutBufferSamples(*tuple(map(self.to_torch, data))) + + def compute_returns_and_advantage( + self, last_value: torch.Tensor, dones: torch.zeros, use_gae: bool = False + ) -> None: + """ + Post-processing step: compute the returns (sum of discounted rewards) + and advantage (A(s) = R - V(S)). + Adapted from Stable-Baselines PPO2. + :param last_value: (torch.Tensor) + :param dones: (torch.zeros) + :param use_gae: (bool) Whether to use Generalized Advantage Estimation + or normal advantage for advantage computation. 
+ """ + last_value = last_value.flatten().reshape(-1, self.num_agents) + + if use_gae: + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_value = last_value + else: + next_non_terminal = 1.0 - self.dones[step + 1] + next_value = self.values[step + 1] + delta = ( + self.rewards[step] + + self.gamma * next_value * next_non_terminal + - self.values[step] + ) + last_gae_lam = ( + delta + + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + ) + self.advantages[step] = last_gae_lam + self.returns = self.advantages + self.values + else: + # Discounted return with value bootstrap + # Note: this is equivalent to GAE computation + # with gae_lambda = 1.0 + last_return = 0.0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_value = last_value + last_return = self.rewards[step] + next_non_terminal * next_value + else: + next_non_terminal = 1.0 - self.dones[step + 1] + last_return = ( + self.rewards[step] + + self.gamma * last_return * next_non_terminal + ) + self.returns[step] = last_return + self.advantages = self.returns - self.values diff --git a/genrl/environments/gym_wrapper.py b/genrl/environments/gym_wrapper.py index ceb99472..cbbc96e4 100644 --- a/genrl/environments/gym_wrapper.py +++ b/genrl/environments/gym_wrapper.py @@ -106,3 +106,107 @@ def close(self) -> None: Closes environment """ self.env.close() + + +class MultiGymWrapper(gym.Wrapper): + """ + Wrapper class for all MultiAgent Particle Environments + + :param env: Gym environment name + :param n_envs: Number of environments. None if not vectorised + :param parallel: If vectorised, should environments be run through \ +serially or parallelly + :type env: string + :type n_envs: None, int + :type parallel: boolean + """ + + def __init__(self, env: gym.Env): + super(MultiGymWrapper, self).__init__(env) + self.env = env + + self.observation_space = self.env.observation_space + self.action_space = self.env.action_space + + self.state = None + self.action = None + self.reward = None + self.done = False + self.info = {} + + def __getattr__(self, name: str) -> Any: + """ + All other calls would go to base env + """ + env = super(MultiGymWrapper, self).__getattribute__("env") + return getattr(env, name) + + @property + def obs_shape(self): + if isinstance(self.env.observation_space, gym.spaces.Discrete): + obs_shape = (1,) + elif isinstance(self.env.observation_space, gym.spaces.Box): + obs_shape = self.env.observation_space.shape + return obs_shape + + @property + def action_shape(self): + if isinstance(self.env.action_space, gym.spaces.Box): + action_shape = self.env.action_space.shape + elif isinstance(self.env.action_space, gym.spaces.Discrete): + action_shape = (1,) + return action_shape + + def sample(self) -> np.ndarray: + """ + Shortcut method to directly sample from environment's action space + + :returns: Random action from action space + :rtype: NumPy Array + """ + return self.env.action_space.sample() + + def render(self, mode: str = "human") -> None: + """ + Renders all envs in a tiles format similar to baselines. + + :param mode: Can either be 'human' or 'rgb_array'. 
\ +Displays tiled images in 'human' and returns tiled images in 'rgb_array' + :type mode: string + """ + self.env.render(mode=mode) + + def seed(self, seed: int = None) -> None: + """ + Set environment seed + + :param seed: Value of seed + :type seed: int + """ + self.env.seed(seed) + + def step(self, action: np.ndarray) -> np.ndarray: + """ + Steps the env through given action + + :param action: Action taken by agent + :type action: NumPy array + :returns: Next observation, reward, game status and debugging info + """ + self.state, self.reward, self.done, self.info = self.env.step(action) + self.action = action + return self.state, self.reward, self.done, self.info + + def reset(self) -> np.ndarray: + """ + Resets environment + + :returns: Initial state + """ + return self.env.reset() + + def close(self) -> None: + """ + Closes environment + """ + self.env.close() diff --git a/genrl/environments/suite.py b/genrl/environments/suite.py index 11310c08..3e69ca40 100644 --- a/genrl/environments/suite.py +++ b/genrl/environments/suite.py @@ -97,3 +97,38 @@ def AtariEnv( env = wrapper(env) return env + + +def MultiAgentParticleEnv(scenario_name: str, benchmark: bool) -> gym.Env: + """ + Function to apply wrappers for all Atari envs by Trainer class + + :param scenarion_name: Environment Name + :type env: string + :param benchmark: laod benchmark results + :type wrapper_list: bool + :returns: Gym Atari Environment + :rtype: object + """ + import multiagent.scenarios as scenarios + from multiagent.environment import MultiAgentEnv + + # load scenario from script + scenario = scenarios.load(scenario_name + ".py").Scenario() + # create world + world = scenario.make_world() + # create multiagent environment + if benchmark: + env = MultiAgentEnv( + world, + scenario.reset_world, + scenario.reward, + scenario.observation, + scenario.benchmark_data, + ) + else: + env = MultiAgentEnv( + world, scenario.reset_world, scenario.reward, scenario.observation + ) + + return env diff --git a/genrl/utils/utils.py b/genrl/utils/utils.py index 89e53337..7e6c849c 100644 --- a/genrl/utils/utils.py +++ b/genrl/utils/utils.py @@ -13,13 +13,13 @@ def get_model(type_: str, name_: str) -> Union: """ - Utility to get the class of required function + Utility to get the class of required function - :param type_: "ac" for Actor Critic, "v" for Value, "p" for Policy - :param name_: Name of the specific structure of model. ( + :param type_: "ac" for Actor Critic, "v" for Value, "p" for Policy + :param name_: Name of the specific structure of model. ( Eg. "mlp" or "cnn") - :type type_: string - :returns: Required class. Eg. MlpActorCritic + :type type_: string + :returns: Required class. Eg. 
MlpActorCritic """ if type_ == "ac": from genrl.core import get_actor_critic_from_name @@ -42,13 +42,13 @@ def mlp( sac: bool = False, ): """ - Generates an MLP model given sizes of each layer + Generates an MLP model given sizes of each layer - :param sizes: Sizes of hidden layers - :param sac: True if Soft Actor Critic is being used, else False - :type sizes: tuple or list - :type sac: bool - :returns: (Neural Network with fully-connected linear layers and + :param sizes: Sizes of hidden layers + :param sac: True if Soft Actor Critic is being used, else False + :type sizes: tuple or list + :type sac: bool + :returns: (Neural Network with fully-connected linear layers and activation layers) """ layers = [] @@ -63,6 +63,166 @@ def mlp( return nn.Sequential(*layers) +# If at all you need to concatenate states to actions after passing states through n FC layers +def mlp_concat( + layer_sizes: Tuple, + weight_init: str = "xavier_uniform", + activation_func: str = "relu", + concat_ind: int = -1, # negative number means no concatenation + sac: bool = False, +): + """ + Generates an MLP model given sizes of each layer + + :param layer_sizes: Sizes of hidden layers + :param weight_init: type of weight initialization + :param activation_func: type of activation function + :param concat_ind: index of layer at which actions to be concatenated + :param sac: True if Soft Actor Critic is being used, else False + :type layer_sizes: tuple or list + :type concat_ind: int + :type sac: bool + :type weight_init,activation_func: string + :returns: (Neural Network with fully-connected linear layers and + activation layers) + """ + layers = [] + limit = len(layer_sizes) if sac is False else len(layer_sizes) - 1 + + # add more activations + activation = nn.Tanh() if activation_func == "tanh" else nn.ReLU() + + # add more weight init + if weight_init == "xavier_uniform": + weight_init = torch.nn.init.xavier_uniform_ + elif weight_init == "xavier_normal": + weight_init = torch.nn.init.xavier_normal_ + + for layer in range(limit - 1): + if layer == concat_ind: + continue + act = activation if layer < limit - 2 else nn.Identity() + layers += [nn.Linear(layer_sizes[layer], layer_sizes[layer + 1])] + weight_init(layers[-1].weight) + layers += [act] + + return nn.Sequential(*layers) + + +def shared_mlp( + network1_prev: Tuple, + network2_prev: Tuple, + shared_layers: Tuple, + network1_post: Tuple, + network2_post: Tuple, + weight_init: str = "xavier_uniform", + activation_func: str = "relu", + sac: bool = False, +): + """ + Generates an MLP model given sizes of each layer (Mostly used for SharedActorCritic) + + :param network1_prev: Sizes of network1's initial layers + :param network2_prev: Sizes of network2's initial layers + :param shared: Sizes of shared layers + :param network1_post: Sizes of network1's latter layers + :param network2_post: Sizes of network2's latter layers + :param weight_init: type of weight initialization + :param activation_func: type of activation function + :param sac: True if Soft Actor Critic is being used, else False + :type network1_prev,network2_prev,shared,network1_post,network2_post: tuple or list + :type weight_init,activation_func: string + :type sac: bool + :returns: network1 and networ2(Neural Network with fully-connected linear layers and + activation layers) + """ + + if len(network1_prev) != 0: + net1_prev = nn.ModuleList() + if len(network2_prev) != 0: + net2_prev = nn.ModuleList() + if len(shared_layers) != 0: + shared = nn.ModuleList() + if len(network1_post) != 0: + 
net1_post = nn.ModuleList() + if len(network2_post) != 0: + net2_post = nn.ModuleList() + + # add more weight init + if weight_init == "xavier_uniform": + weight_init = torch.nn.init.xavier_uniform_ + elif weight_init == "xavier_normal": + weight_init = torch.nn.init.xavier_normal_ + else: + weight_init = None + + if activation_func == "relu": + activation = nn.ReLU() + elif activation_func == "tanh": + activation = nn.Tanh() + else: + activation = None + + if len(shared_layers) != 0 or len(network1_post) != 0 or len(network2_post) != 0: + if len(network1_prev) != 0 and len(network2_prev) != 0: + if not ( + network1_prev[-1] == network2_prev[-1] + and network1_prev[-1] == shared_layers[0] + and network1_post[0] == network2_post[0] + and network1_post[0] == shared_layers[-1] + ): + raise ValueError + else: + if not ( + network1_post[0] == network2_post[0] + and network1_post[0] == shared_layers[-1] + ): + raise ValueError + + for i in range(len(network1_prev) - 1): + net1_prev.append(nn.Linear(network1_prev[i], network1_prev[i + 1])) + if weight_init is not None: + weight_init(net1_prev[-1].weight) + if activation is not None: + net1_prev.append(activation) + + for i in range(len(network2_prev) - 1): + net2_prev.append(nn.Linear(network2_prev[i], network2_prev[i + 1])) + if weight_init is not None: + weight_init(net2_prev[-1].weight) + if activation is not None: + net2_prev.append(activation) + + for i in range(len(shared_layers) - 1): + shared.append(nn.Linear(shared_layers[i], shared_layers[i + 1])) + if weight_init is not None: + weight_init(shared[-1].weight) + if activation is not None: + shared.append(activation) + + for i in range(len(network1_post) - 1): + net1_post.append(nn.Linear(network1_post[i], network1_post[i + 1])) + if weight_init is not None: + weight_init(net1_post[-1].weight) + if activation is not None: + net1_post.append(activation) + + for i in range(len(network2_post) - 1): + net2_post.append(nn.Linear(network2_post[i], network2_post[i + 1])) + if weight_init is not None: + weight_init(net2_post[-1].weight) + if activation is not None: + net2_post.append(activation) + + if len(network1_prev) != 0 and len(network2_prev) != 0: + network1 = nn.Sequential(*net1_prev, *shared, *net1_post) + network2 = nn.Sequential(*net2_prev, *shared, *net2_post) + else: + network1 = nn.Sequential(*shared, *net1_post) + network2 = nn.Sequential(*shared, *net2_post) + return network1, network2 + + def cnn( channels: Tuple = (4, 16, 32), kernel_sizes: Tuple = (8, 4), @@ -70,18 +230,18 @@ def cnn( **kwargs, ) -> (Tuple): """ - (Generates a CNN model given input dimensions, channels, kernel_sizes and + (Generates a CNN model given input dimensions, channels, kernel_sizes and strides) - :param channels: Input output channels before and after each convolution - :param kernel_sizes: Kernel sizes for each convolution - :param strides: Strides for each convolution - :param in_size: Input dimensions (assuming square input) - :type channels: tuple - :type kernel_sizes: tuple - :type strides: tuple - :type in_size: int - :returns: (Convolutional Neural Network with convolutional layers and + :param channels: Input output channels before and after each convolution + :param kernel_sizes: Kernel sizes for each convolution + :param strides: Strides for each convolution + :param in_size: Input dimensions (assuming square input) + :type channels: tuple + :type kernel_sizes: tuple + :type strides: tuple + :type in_size: int + :returns: (Convolutional Neural Network with convolutional layers and activation 
layers)
     """
@@ -107,12 +267,12 @@ def noisy_mlp(fc_layers: List[int], noisy_layers: List[int], activation="relu"):
     """Noisy MLP generating helper function
 
     Args:
-        fc_layers (:obj:`list` of :obj:`int`): List of fully connected layers
-        noisy_layers (:obj:`list` of :obj:`int`): :ist of noisy layers
-        activation (str): Activation function to be used. ["tanh", "relu"]
+        fc_layers (:obj:`list` of :obj:`int`): List of fully connected layers
+        noisy_layers (:obj:`list` of :obj:`int`): List of noisy layers
+        activation (str): Activation function to be used. ["tanh", "relu"]
 
     Returns:
-        Noisy MLP model
+        Noisy MLP model
     """
     model = []
     act = nn.Tanh if activation == "tanh" else nn.ReLU()
@@ -134,15 +294,15 @@ def get_env_properties(
     env: Union[gym.Env, VecEnv], network: Union[str, Any] = "mlp"
 ) -> (Tuple[int]):
     """
-    Finds important properties of environment
+    Finds important properties of environment
 
-    :param env: Environment that the agent is interacting with
-    :type env: Gym Environment
-    :param network: Type of network architecture, eg. "mlp", "cnn"
-    :type network: str
-    :returns: (State space dimensions, Action space dimensions,
+    :param env: Environment that the agent is interacting with
+    :type env: Gym Environment
+    :param network: Type of network architecture, eg. "mlp", "cnn"
+    :type network: str
+    :returns: (State space dimensions, Action space dimensions,
 discreteness of action space and action limit (highest action value)
-    :rtype: int, float, ...; int, float, ...; bool; int, float, ...
+    :rtype: int, float, ...; int, float, ...; bool; int, float, ...
     """
     if network == "cnn":
         state_dim = env.framestack
@@ -199,3 +359,21 @@ def safe_mean(log: Union[torch.Tensor, List[int]]):
     else:
         func = np.mean
     return func(log)
+
+
+def onehot_from_logits(logits, eps=0.0):
+    # get best (according to current policy) actions in one-hot form
+    argmax_acs = (logits == logits.max(0, keepdim=True)[0]).float()
+    if eps == 0.0:
+        return argmax_acs
+    # get random actions in one-hot form
+    rand_acs = torch.eye(logits.shape[1])[
+        [np.random.choice(range(logits.shape[1]), size=logits.shape[0])]
+    ]
+    # chooses between best and random actions using epsilon greedy
+    return torch.stack(
+        [
+            argmax_acs[i] if r > eps else rand_acs[i]
+            for i, r in enumerate(torch.rand(logits.shape[0]))
+        ]
+    )
diff --git a/tests/test_deep/test_agents/test_a2c.py b/tests/test_deep/test_agents/test_a2c.py
index f731f40f..7dc1917e 100644
--- a/tests/test_deep/test_agents/test_a2c.py
+++ b/tests/test_deep/test_agents/test_a2c.py
@@ -23,7 +23,9 @@ def test_a2c_cnn():
 
 def test_a2c_shared():
     env = VectorEnv("CartPole-v0", 1)
+
     algo = A2C("mlp", env, shared_layers=(32, 32), rollout_size=128)
+
     trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
     trainer.train()
     shutil.rmtree("./logs")
diff --git a/tests/test_deep/test_common/test_utils.py b/tests/test_deep/test_common/test_utils.py
index 11bf1292..85dc95ab 100644
--- a/tests/test_deep/test_common/test_utils.py
+++ b/tests/test_deep/test_common/test_utils.py
@@ -8,7 +8,15 @@
 from genrl.core import CnnValue, MlpActorCritic, MlpPolicy, MlpValue
 from genrl.environments import VectorEnv
 from genrl.trainers import OnPolicyTrainer
-from genrl.utils import cnn, get_env_properties, get_model, mlp, set_seeds
+from genrl.utils.utils import (
+    cnn,
+    get_env_properties,
+    get_model,
+    mlp,
+    mlp_concat,
+    set_seeds,
+    shared_mlp,
+)
 
 
 class TestUtils:
@@ -33,15 +41,41 @@ def test_mlp(self):
         sizes = [2, 3, 3, 2]
         mlp_nn = mlp(sizes)
         mlp_nn_sac = mlp(sizes, sac=True)
+
mlp_nn_concat = mlp_concat(sizes, concat_ind=1, sac=False) + mlp_nn_concat_sac = mlp_concat(sizes, concat_ind=1, sac=True) + shared_mlp_nn1, shared_mlp_nn2 = shared_mlp( + sizes, sizes, sizes, sizes, sizes, sac=False + ) + shared_mlp_nn1_sac, shared_mlp_nn2_sac = shared_mlp( + sizes, sizes, sizes, sizes, sizes, sac=True + ) assert len(mlp_nn) == 2 * (len(sizes) - 1) assert all(isinstance(mlp_nn[i], nn.Linear) for i in range(0, 5, 2)) + assert len(mlp_nn_concat) == 2 * (len(sizes) - 1) + assert all(isinstance(mlp_nn_concat[i], nn.Linear) for i in range(0, 5, 2)) assert len(mlp_nn_sac) == 2 * (len(sizes) - 2) assert all(isinstance(mlp_nn_sac[i], nn.Linear) for i in range(0, 4, 2)) + assert len(mlp_nn_concat_sac) == 2 * (len(sizes) - 2) + assert all(isinstance(mlp_nn_concat_sac[i], nn.Linear) for i in range(0, 4, 2)) + assert len(shared_mlp_nn1) == 2 * (len(sizes) - 1) * 3 + assert len(shared_mlp_nn2) == 2 * (len(sizes) - 1) * 3 + assert all(isinstance(shared_mlp_nn1[i], nn.Linear) for i in range(0, 8, 2)) + assert all(isinstance(shared_mlp_nn2[i], nn.Linear) for i in range(0, 8, 2)) + assert len(shared_mlp_nn1_sac) == 2 * (len(sizes) - 2) * 3 + assert all(isinstance(shared_mlp_nn1_sac[i], nn.Linear) for i in range(0, 4, 2)) + assert len(shared_mlp_nn2_sac) == 2 * (len(sizes) - 2) * 3 + assert all(isinstance(shared_mlp_nn2_sac[i], nn.Linear) for i in range(0, 4, 2)) inp = torch.randn((2,)) assert mlp_nn(inp).shape == (2,) + assert mlp_nn_concat(inp).shape == (2,) + assert shared_mlp_nn1(inp).shape == (2,) + assert shared_mlp_nn2(inp).shape == (2,) assert mlp_nn_sac(inp).shape == (3,) + assert mlp_nn_concat_sac(inp).shape == (3,) + assert shared_mlp_nn1_sac(inp).shape == (3,) + assert shared_mlp_nn2_sac(inp).shape == (3,) def test_cnn(self): """
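Usage sketch (not part of the patch): a minimal example of how the new MultiAgentParticleEnv, MultiGymWrapper and MultiAgentReplayBuffer introduced above are intended to fit together. The scenario name "simple_spread", the module-level import paths and the exact per-agent tensor formats pushed into the buffer are assumptions for illustration only, and the example assumes the `multiagent` package is installed.

import torch

from genrl.core.buffers import MultiAgentReplayBuffer
from genrl.environments.gym_wrapper import MultiGymWrapper
from genrl.environments.suite import MultiAgentParticleEnv

# Build a particle environment from a scenario script and wrap it
# ("simple_spread" is an assumed scenario shipped with `multiagent`)
env = MultiGymWrapper(MultiAgentParticleEnv("simple_spread", benchmark=False))
num_agents = len(env.action_space)  # MPE exposes one action space per agent

buffer = MultiAgentReplayBuffer(num_agents=num_agents, capacity=1000)

obs = env.reset()
for _ in range(50):
    # one random action per agent; the exact action format expected by
    # MultiAgentEnv.step depends on the scenario configuration
    actions = [space.sample() for space in env.action_space]
    next_obs, rewards, dones, _ = env.step(actions)

    # push per-agent entries; MultiAgentReplayBuffer.sample() later builds
    # both per-agent and concatenated global batches from these tuples
    buffer.push(
        (
            [torch.as_tensor(o, dtype=torch.float32) for o in obs],
            [torch.as_tensor(a, dtype=torch.float32) for a in actions],
            [float(r) for r in rewards],
            [torch.as_tensor(o, dtype=torch.float32) for o in next_obs],
            torch.as_tensor(dones, dtype=torch.float32),
        )
    )
    obs = next_obs

print(len(buffer))  # number of joint transitions currently stored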