From 256f705c568fb7ec1f61e35b04c8d5d83e61b3d3 Mon Sep 17 00:00:00 2001
From: veds12
Date: Thu, 22 Oct 2020 12:16:52 +0530
Subject: [PATCH] cleaned up multiagent

---
 genrl/agents/multiagent/base/offpolicy.py  |  34 ++++
 genrl/agents/multiagent/maddpg/__init__.py |   0
 genrl/agents/multiagent/maddpg/maddpg.py   | 133 +++++++++++++
 genrl/core/__init__.py                     |   2 +
 genrl/core/buffers.py                      | 120 +++++++++++-
 genrl/core/rollout_storage.py              | 205 ++++++++++++++++++++-
 genrl/utils/__init__.py                    |   3 +
 genrl/utils/pettingzoo_interface.py        |  61 ++++++
 genrl/utils/utils.py                       |  22 ++-
 9 files changed, 564 insertions(+), 16 deletions(-)
 create mode 100644 genrl/agents/multiagent/base/offpolicy.py
 create mode 100644 genrl/agents/multiagent/maddpg/__init__.py
 create mode 100644 genrl/agents/multiagent/maddpg/maddpg.py
 create mode 100644 genrl/utils/pettingzoo_interface.py

diff --git a/genrl/agents/multiagent/base/offpolicy.py b/genrl/agents/multiagent/base/offpolicy.py
new file mode 100644
index 00000000..d755b6f5
--- /dev/null
+++ b/genrl/agents/multiagent/base/offpolicy.py
@@ -0,0 +1,34 @@
+from abc import ABC
+
+
+class MultiAgentOffPolicy(ABC):
+    """Base class for multiagent algorithms with off-policy agents
+
+    Attributes:
+        network (str): The network type of the Q-value function.
+            Supported types: ["cnn", "mlp"]
+        env (Environment): The environment that the agents are supposed to act on
+        agents (list): A list of all the agents to be used
+        create_model (bool): Whether the model of the algo should be created when initialised
+        batch_size (int): Mini batch size for loading experiences
+        gamma (float): The discount factor for rewards
+        layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
+            of the Q-value function
+        lr_policy (float): Learning rate for the policy/actor
+        lr_value (float): Learning rate for the Q-value function
+        replay_size (int): Capacity of the Replay Buffer
+        seed (int): Seed for randomness
+        render (bool): Should the env be rendered during training?
+        device (str): Hardware being used for training. Options:
+            ["cuda" -> GPU, "cpu" -> CPU]
+    """
+
+    def __init__(self, *args, **kwargs):
+        raise NotImplementedError
diff --git a/genrl/agents/multiagent/maddpg/__init__.py b/genrl/agents/multiagent/maddpg/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/genrl/agents/multiagent/maddpg/maddpg.py b/genrl/agents/multiagent/maddpg/maddpg.py
new file mode 100644
index 00000000..c66bb1cc
--- /dev/null
+++ b/genrl/agents/multiagent/maddpg/maddpg.py
@@ -0,0 +1,133 @@
+from abc import ABC
+
+import numpy as np
+import torch
+
+from genrl.agents import DDPG
+from genrl.core import ActionNoise, MultiAgentReplayBuffer
+from genrl.utils import PettingZooInterface, get_model
+
+
+class MADDPG(ABC):
+    """MultiAgent Controller using the MADDPG algorithm
+
+    Attributes:
+        network (str): The network type of the Q-value function of the agents.
+            Supported types: ["mlp"]
+        batch_size (int): Mini batch size for loading experiences
+        gamma (float): The discount factor for rewards
+        layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
+            of the Q-value function
+        shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic if using
+        lr_policy (float): Learning rate for the policy/actor
+        lr_value (float): Learning rate for the critic
+        replay_size (int): Capacity of the Replay Buffer
+        polyak (float): Target model update parameter (1 for hard update)
+        env (Environment): The environment that the agents are supposed to act on
+        render (bool): Should the env be rendered during training?
+        noise (:obj:`ActionNoise`): Action Noise function added to aid in exploration
+        noise_std (float): Standard deviation of the action noise distribution
+        warmup_steps (int): Number of initial steps during which actions are sampled randomly
+        seed (int): Seed for randomness
+        device (str): Hardware being used for training. Options:
+            ["cuda" -> GPU, "cpu" -> CPU]
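+
+    Example (illustrative only; assumes ``env`` is a PettingZoo parallel
+    environment, e.g. one of the ``pettingzoo.mpe`` tasks, or a GenRL wrapper
+    around one exposing ``num_agents``, ``agents``, ``observation_spaces``
+    and ``action_spaces``)::
+
+        agent = MADDPG(env=env, warmup_steps=500)
+        agent.train(max_episode=100, max_steps=25, batch_size=32)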
+    """
+
+    def __init__(
+        self,
+        *args,
+        env,
+        network: str = "mlp",
+        shared_layers=None,
+        replay_size: int = int(1e6),
+        render: bool = False,
+        noise: ActionNoise = None,
+        noise_std: float = 0.2,
+        warmup_steps: int = 1000,
+        **kwargs,
+    ):
+        self.env = env
+        self.network = network
+        self.shared_layers = shared_layers
+        self.num_agents = self.env.num_agents
+        self.replay_buffer = MultiAgentReplayBuffer(self.num_agents, replay_size)
+        self.render = render
+        self.warmup_steps = warmup_steps
+        # The interface is created before the agents so that _create_model
+        # can query the environment's spaces; the agent list is filled in
+        # once the agents have been constructed
+        self.EnvInterface = PettingZooInterface(self.env, agents_list=None)
+        ac = self._create_model()
+        self.agents = [
+            DDPG(ac, noise=noise, noise_std=noise_std, **kwargs)
+            for agent in self.env.agents
+        ]
+        self.EnvInterface.agents_list = self.agents
+
+    def _create_model(self):
+        state_dim, action_dim, discrete, _ = self.EnvInterface.get_env_properties()
+        if discrete:
+            raise NotImplementedError(
+                "Discrete environments are not supported for {}.".format(
+                    self.__class__.__name__
+                )
+            )
+        return get_model("ac", self.network)(
+            state_dim, action_dim, self.shared_layers,
+        )
+
+    def update(self, batch_size):
+        (
+            obs_batch,
+            indiv_action_batch,
+            indiv_reward_batch,
+            next_obs_batch,
+            global_state_batch,
+            global_actions_batch,
+            global_next_state_batch,
+            done_batch,
+        ) = self.replay_buffer.sample(batch_size)
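+        # MADDPG trains with centralised critics and decentralised actors:
+        # each agent's critic is updated with the joint (global) state and
+        # action batches, while its actor only conditions on that agent's
+        # own observations.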
+        for i in range(self.num_agents):
+            obs_batch_i = obs_batch[i]
+            indiv_action_batch_i = indiv_action_batch[i]
+            indiv_reward_batch_i = indiv_reward_batch[i]
+            next_obs_batch_i = next_obs_batch[i]
+
+            # The environment interface is expected to compute each agent's
+            # next action (via its target actor) from the next observations
+            (
+                next_obs_batch_i,
+                indiv_next_action,
+                next_global_actions,
+            ) = self.EnvInterface.trainer(next_obs_batch_i)
+            next_global_actions = torch.cat(
+                [next_actions_i for next_actions_i in next_global_actions], 1
+            )
+            self.EnvInterface.update_agents(
+                indiv_reward_batch_i,
+                obs_batch_i,
+                global_state_batch,
+                global_actions_batch,
+                global_next_state_batch,
+                next_global_actions,
+            )
+
+    def train(self, max_episode, max_steps, batch_size):
+        episode_rewards = []
+        timestep = 0
+        for episode in range(max_episode):
+            states = self.env.reset()
+            episode_reward = 0
+            for step in range(max_steps):
+                if self.render:
+                    self.env.render(mode="human")
+
+                timestep += 1
+                actions = self.EnvInterface.get_actions(
+                    states, timestep, self.warmup_steps
+                )
+                next_states, rewards, dones, _ = self.env.step(actions)
+                rewards = self.EnvInterface.flatten(rewards)
+                episode_reward += np.mean(rewards)
+                dones = self.EnvInterface.flatten(dones)
+
+                if all(dones) or step == max_steps - 1:
+                    dones = [1 for _ in range(self.num_agents)]
+                    self.replay_buffer.push(
+                        (states, actions, rewards, next_states, dones)
+                    )
+                    episode_rewards.append(episode_reward)
+                    print(
+                        f"Episode: {episode + 1} | Steps Taken: {step + 1} | Reward: {episode_reward}"
+                    )
+                    break
+                else:
+                    dones = [0 for _ in range(self.num_agents)]
+                    self.replay_buffer.push(
+                        (states, actions, rewards, next_states, dones)
+                    )
+                    states = next_states
+                    if len(self.replay_buffer) > batch_size:
+                        self.update(batch_size)
diff --git a/genrl/core/__init__.py b/genrl/core/__init__.py
index 96824ac1..ca118516 100644
--- a/genrl/core/__init__.py
+++ b/genrl/core/__init__.py
@@ -5,6 +5,7 @@
 from genrl.core.buffers import PrioritizedReplayBufferSamples  # noqa
 from genrl.core.buffers import ReplayBuffer  # noqa
 from genrl.core.buffers import ReplayBufferSamples  # noqa
+from genrl.core.buffers import MultiAgentReplayBuffer  # noqa
 from genrl.core.noise import ActionNoise  # noqa
 from genrl.core.noise import NoisyLinear  # noqa
 from genrl.core.noise import NormalActionNoise  # noqa
@@ -16,6 +17,7 @@
     get_policy_from_name,
 )
 from genrl.core.rollout_storage import RolloutBuffer  # noqa
+from genrl.core.rollout_storage import MultiAgentRolloutBuffer  # noqa
 from genrl.core.values import (  # noqa
     BaseValue,
     CnnCategoricalValue,
diff --git a/genrl/core/buffers.py b/genrl/core/buffers.py
index 0a5b6e7c..5fd3f526 100644
--- a/genrl/core/buffers.py
+++ b/genrl/core/buffers.py
@@ -146,15 +146,7 @@ def sample(
         return [
             torch.as_tensor(v, dtype=torch.float32)
-            for v in [
-                states,
-                actions,
-                rewards,
-                next_states,
-                dones,
-                indices,
-                weights,
-            ]
+            for v in [states, actions, rewards, next_states, dones, indices, weights,]
         ]
 
     def update_priorities(self, batch_indices: Tuple, batch_priorities: Tuple) -> None:
@@ -181,3 +173,113 @@ def __len__(self) -> int:
 
     @property
     def pos(self):
         return len(self.buffer)
+
+
+class MultiAgentReplayBuffer:
+    """
+    Implements the basic Experience Replay Mechanism for multi-agent settings
+    by storing global states, global actions, global rewards,
+    global next_states and global dones
+
+    :param num_agents: Number of agents in the environment
+    :type num_agents: int
+    :param capacity: Size of the replay buffer
+    :type capacity: int
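+
+    Example (illustrative; ``states``, ``actions``, ``rewards`` and
+    ``next_states`` are per-agent sequences, e.g. a list with one tensor per
+    agent, and ``dones`` is a list of ints)::
+
+        buffer = MultiAgentReplayBuffer(num_agents=2, capacity=int(1e5))
+        buffer.push((states, actions, rewards, next_states, dones))
+        batch = buffer.sample(batch_size=32)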
+    """
+
+    def __init__(self, num_agents: int, capacity: int):
+        """
+        Initialising the buffer
+
+        :param num_agents: Number of agents in the environment
+        :type num_agents: int
+        :param capacity: Max buffer size
+        :type capacity: int
+        """
+        self.capacity = capacity
+        self.num_agents = num_agents
+        self.buffer = deque(maxlen=self.capacity)
+
+    def push(self, inp: Tuple) -> None:
+        """
+        Adds new experience to buffer
+
+        :param inp: Tuple containing `state`, `action`, `reward`,
+            `next_state` and `done`
+        :type inp: tuple
+        :returns: None
+        """
+        self.buffer.append(inp)
+
+    def sample(self, batch_size):
+        """
+        Returns randomly sampled experiences from replay memory
+
+        :param batch_size: Number of samples per batch
+        :type batch_size: int
+        :returns: Tuple composed of `indiv_obs_batch`,
+            `indiv_action_batch`, `indiv_reward_batch`, `indiv_next_obs_batch`,
+            `global_state_batch`, `global_actions_batch`, `global_next_state_batch`
+            and `done_batch`
+        """
+        indiv_obs_batch = [
+            [] for _ in range(self.num_agents)
+        ]  # [ [states of agent 1], ..., [states of agent n] ]
+        indiv_action_batch = [
+            [] for _ in range(self.num_agents)
+        ]  # [ [actions of agent 1], ..., [actions of agent n] ]
+        indiv_reward_batch = [[] for _ in range(self.num_agents)]
+        indiv_next_obs_batch = [[] for _ in range(self.num_agents)]
+
+        global_state_batch = []
+        global_next_state_batch = []
+        global_actions_batch = []
+        done_batch = []
+
+        batch = random.sample(self.buffer, batch_size)
+
+        for experience in batch:
+            state, action, reward, next_state, done = experience
+
+            for i in range(self.num_agents):
+                indiv_obs_batch[i].append(state[i])
+                indiv_action_batch[i].append(action[i])
+                indiv_reward_batch[i].append(reward[i])
+                indiv_next_obs_batch[i].append(next_state[i])
+
+            global_state_batch.append(torch.cat(state))
+            global_actions_batch.append(torch.cat(action))
+            global_next_state_batch.append(torch.cat(next_state))
+            done_batch.append(done)
+
+        global_state_batch = torch.stack(global_state_batch)
+        global_actions_batch = torch.stack(global_actions_batch)
+        global_next_state_batch = torch.stack(global_next_state_batch)
+        done_batch = torch.FloatTensor(done_batch)
+        indiv_obs_batch = torch.stack(
+            [torch.FloatTensor(obs) for obs in indiv_obs_batch]
+        )
+        indiv_action_batch = torch.stack(
+            [torch.FloatTensor(act) for act in indiv_action_batch]
+        )
+        indiv_reward_batch = torch.stack(
+            [torch.FloatTensor(rew) for rew in indiv_reward_batch]
+        )
+        indiv_next_obs_batch = torch.stack(
+            [torch.FloatTensor(next_obs) for next_obs in indiv_next_obs_batch]
+        )
+
+        return (
+            indiv_obs_batch,
+            indiv_action_batch,
+            indiv_reward_batch,
+            indiv_next_obs_batch,
+            global_state_batch,
+            global_actions_batch,
+            global_next_state_batch,
+            done_batch,
+        )
+
+    def __len__(self):
+        """
+        Gives the number of experiences currently in the buffer
+
+        :returns: Length of replay memory
+        """
+        return len(self.buffer)
diff --git a/genrl/core/rollout_storage.py b/genrl/core/rollout_storage.py
index 16d1c721..616434a2 100644
--- a/genrl/core/rollout_storage.py
+++ b/genrl/core/rollout_storage.py
@@ -102,8 +102,7 @@ def reset(self) -> None:
         self.full = False
 
     def sample(
-        self,
-        batch_size: int,
+        self, batch_size: int,
     ):
         """
         :param batch_size: (int) Number of element to sample
@@ -114,8 +113,7 @@ def sample(
         return self._get_samples(batch_inds)
 
     def _get_samples(
-        self,
-        batch_inds: np.ndarray,
+        self, batch_inds: np.ndarray,
     ):
         """
         :param batch_inds: (torch.Tensor)
@@ -257,3 +255,202 @@ def _get_samples(self, batch_inds: np.ndarray) -> RolloutBufferSamples:
             self.returns[batch_inds].flatten(),
         )
         return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
+
+
+class MultiAgentRolloutBuffer(BaseBuffer):
+    """
+    Rollout buffer used in on-policy multi-agent algorithms like MAA2C/MAA3C.
+
+    :param num_agents: (int) Max number of agents in the environment
+    :param buffer_size: (int) Max number of elements in the buffer
+    :param env: (Environment) The environment being trained on
+    :param device: (torch.device)
+    :param gae_lambda: (float) Factor for trade-off of bias vs variance for the
+        Generalized Advantage Estimator. Equivalent to classic advantage when set to 1.
+    :param gamma: (float) Discount factor
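+
+    Note: tensors are stored with shape
+    ``(buffer_size, env.n_envs, num_agents, *feature_shape)``; the number of
+    parallel environments is taken from ``env.n_envs``.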
+    """
+
+    def __init__(
+        self,
+        num_agents: int,
+        buffer_size: int,
+        env,
+        device: Union[torch.device, str] = "cpu",
+        gae_lambda: float = 1,
+        gamma: float = 0.99,
+    ):
+        super(MultiAgentRolloutBuffer, self).__init__(buffer_size, env, device)
+
+        self.buffer_size = buffer_size
+        self.num_agents = num_agents
+        self.env = env
+        self.device = device
+        self.gae_lambda = gae_lambda
+        self.gamma = gamma
+
+        self.observations, self.actions, self.rewards, self.advantages = (
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+        )
+        self.returns, self.dones, self.values, self.log_probs = (
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+            [None] * self.num_agents,
+        )
+        self.generator_ready = False
+        self.reset()
+
+    def reset(self) -> None:
+        self.observations = torch.zeros(
+            *(self.buffer_size, self.env.n_envs, self.num_agents, *self.env.obs_shape)
+        )
+        self.actions = torch.zeros(
+            *(
+                self.buffer_size,
+                self.env.n_envs,
+                self.num_agents,
+                *self.env.action_shape,
+            )
+        )
+        self.rewards = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.returns = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.dones = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.values = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.log_probs = torch.zeros(self.buffer_size, self.env.n_envs, self.num_agents)
+        self.advantages = torch.zeros(
+            self.buffer_size, self.env.n_envs, self.num_agents
+        )
+        self.generator_ready = False
+        super(MultiAgentRolloutBuffer, self).reset()
+
+    def add(
+        self,
+        obs: torch.Tensor,
+        action: torch.Tensor,
+        reward: torch.Tensor,
+        done: torch.Tensor,
+        value: torch.Tensor,
+        log_prob: torch.Tensor,
+    ) -> None:
+        """
+        :param obs: (torch.Tensor) Observation
+        :param action: (torch.Tensor) Action
+        :param reward: (torch.Tensor) Reward received at this timestep
+        :param done: (torch.Tensor) End of episode signal.
+        :param value: (torch.Tensor) estimated value of the current state
+            following the current policy.
+        :param log_prob: (torch.Tensor) log probability of the action
+            following the current policy.
+ """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + self.observations[self.pos] = obs.detach().clone() + self.actions[self.pos] = action.squeeze().detach().clone() + self.rewards[self.pos] = reward.detach().clone() + self.dones[self.pos] = done.detach().clone() + self.values[self.pos] = ( + value.detach().clone().flatten().reshape(-1, self.num_agents) + ) + self.log_probs[self.pos] = ( + log_prob.detach().clone().flatten().reshape(-1, self.num_agents) + ) + self.pos += 1 + + if self.pos == self.buffer_size: + self.full = True + + def get( + self, batch_size: Optional[int] = None + ) -> Generator[RolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.env.n_envs) + # Prepare the data + if not self.generator_ready: + for tensor in [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ]: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.env.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.env.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray) -> RolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten().reshape(-1, self.num_agents), + self.log_probs[batch_inds].flatten().reshape(-1, self.num_agents), + self.advantages[batch_inds].flatten().reshape(-1, self.num_agents), + self.returns[batch_inds].flatten().reshape(-1, self.num_agents), + ) + return RolloutBufferSamples(*tuple(map(self.to_torch, data))) + + def compute_returns_and_advantage( + self, last_value: torch.Tensor, dones: torch.zeros, use_gae: bool = False + ) -> None: + """ + Post-processing step: compute the returns (sum of discounted rewards) + and advantage (A(s) = R - V(S)). + Adapted from Stable-Baselines PPO2. + :param last_value: (torch.Tensor) + :param dones: (torch.zeros) + :param use_gae: (bool) Whether to use Generalized Advantage Estimation + or normal advantage for advantage computation. 
+ """ + last_value = last_value.flatten().reshape(-1, self.num_agents) + + if use_gae: + last_gae_lam = 0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_value = last_value + else: + next_non_terminal = 1.0 - self.dones[step + 1] + next_value = self.values[step + 1] + delta = ( + self.rewards[step] + + self.gamma * next_value * next_non_terminal + - self.values[step] + ) + last_gae_lam = ( + delta + + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam + ) + self.advantages[step] = last_gae_lam + self.returns = self.advantages + self.values + else: + # Discounted return with value bootstrap + # Note: this is equivalent to GAE computation + # with gae_lambda = 1.0 + last_return = 0.0 + for step in reversed(range(self.buffer_size)): + if step == self.buffer_size - 1: + next_non_terminal = 1.0 - dones + next_value = last_value + last_return = self.rewards[step] + next_non_terminal * next_value + else: + next_non_terminal = 1.0 - self.dones[step + 1] + last_return = ( + self.rewards[step] + + self.gamma * last_return * next_non_terminal + ) + self.returns[step] = last_return + self.advantages = self.returns - self.values diff --git a/genrl/utils/__init__.py b/genrl/utils/__init__.py index b7f4070d..e0ff5c4e 100644 --- a/genrl/utils/__init__.py +++ b/genrl/utils/__init__.py @@ -20,4 +20,7 @@ noisy_mlp, safe_mean, set_seeds, + onehot_from_logits ) +from genrl.utils.pettingzoo_interface import PettingZooInterface # noqa + diff --git a/genrl/utils/pettingzoo_interface.py b/genrl/utils/pettingzoo_interface.py new file mode 100644 index 00000000..12574d00 --- /dev/null +++ b/genrl/utils/pettingzoo_interface.py @@ -0,0 +1,61 @@ +from abc import ABC, abstractmethod + +import gym +import numpy as np + + +class PettingZooInterface(ABC): + """ + An interface between the PettingZoo API and agents defined in GenRL + + Attributes: + + env (PettingZoo Environment) : The environments in which the agents are acting + agents_list (list) : A list containing all the agent objects present in the environment + """ + + def __init__(self, env, agents_list): + self.env = env + self.agents_list = agents_list + + def get_env_properties(self, network: str): + state_dim = list(self.env.observation_spaces.values())[0].shape[0] + if isinstance(list(self.env.action_spaces.vales())[0], gym.spaces.Discrete): + discrete = True + action_dim = list(self.env.action_spaces.values())[0].n + action_lim = None + elif isinstance(list(self.env.action_spaces.values())[0], gym.spaces.Box): + discrete = False + action_dim = list(self.env.action_spaces.values())[0].shape[0] + action_lim = list(self.env.action_spaces.values())[0].high[0] + else: + NotImplementedError + + return state_dim, action_dim, discrete, action_lim + + def get_actions(self, states, steps, warmup_steps): + if steps < warmup_steps: + actions = {agent: self.env.action_spaces[agent].sample() for key in states} + else: + actions = { + agent: self.agents_list[i].select_action(torch.tensor(states[agent])) + for i, key in enumerate(states) + } + return actions + + def flatten(self, object): + flattened_object = np.array([object[agent] for agent in self.env.agents]) + return flattened_object + + def trainer(self, action): + raise NotImplementedError + + def update_agents( + indiv_reward_batch_i, + obs_batch_i, + global_state_batch, + global_actions_batch, + global_next_state_batch, + next_global_actions, + ): + raise NotImplementedError diff --git a/genrl/utils/utils.py 
+    # get best (according to current policy) actions in one-hot form
+    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()
+    if eps == 0.0:
+        return argmax_acs
+    # get random actions in one-hot form
+    rand_acs = torch.eye(logits.shape[1])[
+        np.random.choice(range(logits.shape[1]), size=logits.shape[0])
+    ]
+    # choose between best and random actions using epsilon greedy
+    return torch.stack(
+        [
+            argmax_acs[i] if r > eps else rand_acs[i]
+            for i, r in enumerate(torch.rand(logits.shape[0]))
+        ]
+    )