barebones maddpg
veds12 committed Oct 9, 2020
1 parent ff6e9c7 commit 330adbd
Showing 5 changed files with 148 additions and 0 deletions.
Empty file.
82 changes: 82 additions & 0 deletions genrl/agents/multiagent/maddpg/maddpg.py
@@ -0,0 +1,82 @@
from abc import ABC

import numpy as np
import torch

from genrl.agents import DDPG
from genrl.core import ActionNoise, MultiAgentReplayBuffer  # ActionNoise location assumed; MultiAgentReplayBuffer imported as in offpolicy.py
from genrl.utils import PettingZooInterface

class MADDPG(ABC):
"""MultiAgent Controller using the MADDPG algorithm
Attributes:
network (str): The network type of the Q-value function of the agents.
Supported types: ["cnn", "mlp"]
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
        shared_layers (:obj:`tuple` of :obj:`int`): Sizes of shared layers in Actor Critic, if used
        lr_policy (float): Learning rate for the policy/actor
        lr_value (float): Learning rate for the critic
        replay_size (int): Capacity of the Replay Buffer
        polyak (float): Target model update parameter (1 for hard update)
        env (Environment): The environment that the agent is supposed to act on
render (bool): Should the env be rendered during training?
noise (:obj:`ActionNoise`): Action Noise function added to aid in exploration
noise_std (float): Standard deviation of the action noise distribution
seed (int): Seed for randomness
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""

    def __init__(self, *args, env, replay_size: int = int(1e6), render: bool = False, noise: ActionNoise = None, noise_std: float = 0.2, warmup_steps: int = 1000, **kwargs):
        self.env = env
        self.num_agents = self.env.num_agents
        # shared multi-agent buffer, sized by replay_size
        self.replay_buffer = MultiAgentReplayBuffer(self.num_agents, replay_size)
        # one DDPG learner per agent; noise and noise_std are keyword-only in DDPG
        self.agents = [DDPG(*args, noise=noise, noise_std=noise_std, **kwargs) for _ in self.env.agents]
        self.EnvInterface = PettingZooInterface(self.env, self.agents)
        self.render = render
        self.warmup_steps = warmup_steps

    def update(self, batch_size):
        # Sample a joint batch: per-agent observations/actions/rewards plus the global (concatenated) views
        obs_batch, indiv_action_batch, indiv_reward_batch, next_obs_batch, \
            global_state_batch, global_actions_batch, global_next_state_batch, done_batch = self.replay_buffer.sample(batch_size)
        for i in range(self.num_agents):
            obs_batch_i = obs_batch[i]
            indiv_action_batch_i = indiv_action_batch[i]
            indiv_reward_batch_i = indiv_reward_batch[i]
            next_obs_batch_i = next_obs_batch[i]
            # trainer() is assumed to map next observations to target-policy actions;
            # the hook is still a stub in PettingZooInterface
            next_obs_batch_i, indiv_next_action, next_global_actions = self.EnvInterface.trainer(next_obs_batch_i)
            next_global_actions = torch.cat([next_actions_i for next_actions_i in next_global_actions], 1)
            self.EnvInterface.update_agents(indiv_reward_batch_i, obs_batch_i, global_state_batch, global_actions_batch, global_next_state_batch, next_global_actions)

    def train(self, max_episode, max_steps, batch_size):
        episode_rewards = []
        timestep = 0  # total environment steps taken, used for the warmup check
        for episode in range(max_episode):
            states = self.env.reset()
            episode_reward = 0
            for step in range(max_steps):
                if self.render:
                    self.env.render(mode='human')

                timestep += 1
                actions = self.EnvInterface.get_actions(states, timestep, self.warmup_steps)
                next_states, rewards, dones, _ = self.env.step(actions)
                rewards = self.EnvInterface.flatten(rewards)
                episode_reward += np.mean(rewards)
                dones = self.EnvInterface.flatten(dones)
                if all(dones) or step == max_steps - 1:
                    dones = [1 for _ in range(self.num_agents)]
                    self.replay_buffer.push(states, actions, rewards, next_states, dones)
                    episode_rewards.append(episode_reward)
                    print(f"Episode: {episode + 1} | Steps Taken: {step + 1} | Reward: {episode_reward}")
                    break
                else:
                    dones = [0 for _ in range(self.num_agents)]
                    self.replay_buffer.push(states, actions, rewards, next_states, dones)
                    states = next_states
                    if len(self.replay_buffer) > batch_size:
                        self.update(batch_size)
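
Since the new files are only a skeleton (the `trainer` and `update_agents` hooks further below are still stubs), the commit ships no runnable example. As context for the dict-shaped data that `MADDPG.train` and `PettingZooInterface` pass around, here is a small, self-contained PettingZoo parallel-API loop with random actions; the `simple_spread_v2` environment is an illustrative choice and is not referenced anywhere in the commit.

# Illustrative PettingZoo parallel-API loop showing the per-agent dicts MADDPG.train() consumes.
# Uses the same API surface as the commit: dict `action_spaces` and a four-value `step()`.
from pettingzoo.mpe import simple_spread_v2

env = simple_spread_v2.parallel_env()
states = env.reset()                                        # {agent_name: observation}
for _ in range(25):
    # random per-agent actions, keyed exactly as get_actions() keys them
    actions = {agent: env.action_spaces[agent].sample() for agent in env.agents}
    next_states, rewards, dones, infos = env.step(actions)  # all four are per-agent dicts
    states = next_states
    if all(dones.values()):
        break
env.close()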

30 changes: 30 additions & 0 deletions genrl/agents/multiagent/offpolicy.py
@@ -0,0 +1,30 @@
import collections
from abc import ABC

import torch
import torch.nn as nn
import torch.optim as opt

from genrl.core import MultiAgentReplayBuffer
from genrl.utils import MultiAgentEnvInterface

class MultiAgentOffPolicy(ABC):
"""Base class for multiagent algorithms with OffPolicy agents
Attributes:
network (str): The network type of the Q-value function.
Supported types: ["cnn", "mlp"]
env (Environment): The environment that the agent is supposed to act on
agents (list) : A list of all the agents to be used
create_model (bool): Whether the model of the algo should be created when initialised
batch_size (int): Mini batch size for loading experiences
gamma (float): The discount factor for rewards
layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network
of the Q-value function
lr_policy (float): Learning rate for the policy/actor
lr_value (float): Learning rate for the Q-value function
replay_size (int): Capacity of the Replay Buffer
seed (int): Seed for randomness
render (bool): Should the env be rendered during training?
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""
    # Base-class behaviour is left unimplemented for now; concrete multi-agent
    # off-policy trainers (such as MADDPG) are expected to subclass this.
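
The base class body is still empty, so the following is only a hypothetical sketch of how a concrete subclass might wire up the attributes listed in the docstring; the class name `MultiAgentExampleAgent` and all hyperparameter values are invented for illustration.

# Hypothetical subclass sketch: one way the documented attributes could be wired together.
class MultiAgentExampleAgent(MultiAgentOffPolicy):
    def __init__(self, network, env, batch_size=64, gamma=0.99, lr_policy=1e-3, lr_value=1e-3, replay_size=int(1e6)):
        self.network = network                # "mlp" or "cnn"
        self.env = env
        self.batch_size = batch_size
        self.gamma = gamma
        self.lr_policy = lr_policy
        self.lr_value = lr_value
        # one shared buffer holding per-agent transitions plus the joint (global) view
        self.replay_buffer = MultiAgentReplayBuffer(env.num_agents, replay_size)
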
34 changes: 34 additions & 0 deletions genrl/environments/pettingzoo_interface.py
@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
import numpy as np

class PettingZooInterface(ABC):
"""
    An interface between the PettingZoo API and agents defined in GenRL
Attributes:
        env (PettingZoo Environment) : The environment in which the agents are acting
agents_list (list) : A list containing all the agent objects present in the environment
"""

def __init__(self, env, agents_list):
self.env = env
self.agents_list = agents_list

def get_actions(self, states, steps, warmup_steps):
if steps < warmup_steps:
actions = {key : self.env.action_spaces[key].sample() for key in states}
else:
actions = {key : self.agents_list[i].select_action(states[key]) for i, key in enumerate(states)}
return actions

    def flatten(self, obj):
        """Collapse a per-agent dict into a numpy array ordered by self.env.agents"""
        flattened_object = np.array([obj[agent] for agent in self.env.agents])
        return flattened_object

def trainer(self, action):
raise NotImplementedError

    def update_agents(self, indiv_reward_batch_i, obs_batch_i, global_state_batch, global_actions_batch, global_next_state_batch, next_global_actions):
raise NotImplementedError
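
A hypothetical usage sketch for the two helpers that are already implemented; the import path mirrors the one used in maddpg.py (`genrl.utils`) even though the module itself is added under genrl/environments, and the environment choice is again illustrative.

# Hypothetical usage of get_actions() and flatten() with a PettingZoo parallel env.
from pettingzoo.mpe import simple_spread_v2

from genrl.utils import PettingZooInterface    # import path as used in maddpg.py (assumed export)

env = simple_spread_v2.parallel_env()
states = env.reset()

# With an empty agents_list the interface can still act during warmup,
# because every action is sampled from the environment's action spaces.
interface = PettingZooInterface(env, agents_list=[])
actions = interface.get_actions(states, steps=0, warmup_steps=1000)

next_states, rewards, dones, _ = env.step(actions)
print(interface.flatten(rewards))              # per-agent rewards as a numpy array ordered by env.agents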

2 changes: 2 additions & 0 deletions genrl/utils/utils.py
@@ -377,3 +377,5 @@ def onehot_from_logits(self, logits, eps=0.0):
for i, r in enumerate(torch.rand(logits.shape[0]))
]
)

