Commit: CEM agent
hades-rp2010 committed Oct 5, 2020
1 parent d3830e0 commit a90e8d0
Showing 6 changed files with 250 additions and 0 deletions.
1 change: 1 addition & 0 deletions genrl/agents/__init__.py
@@ -40,5 +40,6 @@
from genrl.agents.deep.sac.sac import SAC # noqa
from genrl.agents.deep.td3.td3 import TD3 # noqa
from genrl.agents.deep.vpg.vpg import VPG # noqa
from genrl.agents.modelbased.base import ModelBasedAgent  # noqa

from genrl.agents.bandits.multiarmed.base import MABAgent  # noqa
Empty file.
66 changes: 66 additions & 0 deletions genrl/agents/modelbased/base.py
@@ -0,0 +1,66 @@
from abc import ABC

import numpy as np
import torch


class Planner:
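    """Base class for planners used by model-based agents.

    Args:
        initial_state: State the planning starts from
        dynamics_model: Optional dynamics model; if not given, a subclass is
            expected to learn one via ``_learn_dynamics_model``
    """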
    def __init__(self, initial_state, dynamics_model=None):
        if dynamics_model is not None:
            self.dynamics_model = dynamics_model
        self.initial_state = initial_state

    def _learn_dynamics_model(self, state):
        raise NotImplementedError

    def plan(self):
        raise NotImplementedError

    def execute_actions(self):
        raise NotImplementedError


class ModelBasedAgent(ABC):
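    """Base class for model-based agents.

    Args:
        env: Environment the agent interacts with
        planner (:obj:`Planner`): Planner used to plan out sequences of actions
        render (bool): Whether to render the environment during rollouts
        device (str): Device on which tensors are stored ("cpu" or "cuda")
    """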
    def __init__(self, env, planner=None, render=False, device="cpu"):
        self.env = env
        self.planner = planner
        self.render = render
        self.device = torch.device(device)

    def plan(self):
        """
        To be used to plan out a sequence of actions
        """
        if self.planner is None:
            raise ValueError("Provide a planner to plan for the environment")
        self.planner.plan()

    def generate_data(self):
        """
        To be used to generate synthetic data via a model (may be learnt or specified beforehand)
        """
        raise NotImplementedError

    def value_equivalence(self, state_space):
        """
        To be used for approximate value estimation methods, e.g. Value Iteration Networks
        """
        raise NotImplementedError

    def update_params(self):
        """
        Update the parameters (of the learnt model and/or of the policy being used)
        """
        raise NotImplementedError

    def get_hyperparams(self):
        raise NotImplementedError

    def get_logging_params(self):
        raise NotImplementedError

    def _load_weights(self, weights):
        raise NotImplementedError

    def empty_logs(self):
        raise NotImplementedError
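As a rough illustration (not part of this commit), a subclass supplies a planner and fills in the abstract hooks; the RandomShooter name below is hypothetical:

    class RandomShooter(ModelBasedAgent):
        def update_params(self):
            # Planning is delegated to the injected Planner via the base class
            self.plan()

        def get_logging_params(self):
            return {}

        def empty_logs(self):
            pass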
Empty file.
171 changes: 171 additions & 0 deletions genrl/agents/modelbased/cem/cem.py
@@ -0,0 +1,171 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from genrl.agents import ModelBasedAgent
from genrl.core import RolloutBuffer
from genrl.utils import get_env_properties, get_model, safe_mean


class CEM(ModelBasedAgent):
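    """CEM agent: optimizes a policy with the Cross Entropy Method.

    Repeatedly rolls out the current policy, keeps the top trajectories by
    total reward, and fits the policy to the elite actions.

    Args:
        network (str): Type of the policy network, e.g. "mlp"
        policy_layers (tuple): Sizes of the policy network's hidden layers
        percentile (int): Percentile threshold used to pick elite trajectories
    """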
    def __init__(
        self,
        *args,
        network: str = "mlp",
        policy_layers: tuple = (100,),
        percentile: int = 70,
        **kwargs
    ):
        super(CEM, self).__init__(*args, **kwargs)
        self.network = network
        self.rollout_size = int(1e4)
        self.rollout = RolloutBuffer(self.rollout_size, self.env)
        self.policy_layers = policy_layers
        self.percentile = percentile

        self._create_model()
        self.empty_logs()

    def _create_model(self):
        self.state_dim, self.action_dim, discrete, action_lim = get_env_properties(
            self.env, self.network
        )
        self.agent = get_model("p", self.network)(
            self.state_dim,
            self.action_dim,
            self.policy_layers,
            "V",
            discrete,
            action_lim,
        )
        self.optim = torch.optim.Adam(
            self.agent.parameters(), lr=1e-3
        )  # make this a hyperparam

    def plan(self, timesteps=1e4):
        state = self.env.reset()
        self.rollout.reset()
        _, _ = self.collect_rollouts(state)
        return (
            self.rollout.observations,
            self.rollout.actions,
            torch.sum(self.rollout.rewards).detach(),
        )

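    # Elite selection: set a reward threshold at the given percentile of the
    # batch rewards; every state/action pair from a trajectory whose total
    # reward meets the threshold becomes a training example for the policy.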
    def select_elites(self, states_batch, actions_batch, rewards_batch):
        reward_threshold = np.percentile(rewards_batch, self.percentile)
        elite_states = [
            s.unsqueeze(0)
            for i in range(len(states_batch))
            if rewards_batch[i] >= reward_threshold
            for s in states_batch[i]
        ]
        elite_actions = [
            a.unsqueeze(0)
            for i in range(len(actions_batch))
            if rewards_batch[i] >= reward_threshold
            for a in actions_batch[i]
        ]

        return torch.cat(elite_states, dim=0), torch.cat(elite_actions, dim=0)

    def select_action(self, state):
        state = torch.as_tensor(state).float()
        action, dist = self.agent.get_action(state)
        return action

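    # CEM update: roll out the current policy, keep the elite trajectories,
    # and fit the policy to the elite actions with a cross-entropy loss
    # (maximum likelihood on the elite state-action pairs).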
    def update_params(self):
        sess = [self.plan() for _ in range(100)]  # rollouts per update; make this a hyperparam
        batch_states, batch_actions, batch_rewards = zip(*sess)
        elite_states, elite_actions = self.select_elites(
            batch_states, batch_actions, batch_rewards
        )
        action_probs = self.agent.forward(torch.as_tensor(elite_states).float())
        loss = F.cross_entropy(
            action_probs.view(-1, self.action_dim),
            torch.as_tensor(elite_actions).long().view(-1),
        )
        self.logs["crossentropy_loss"].append(loss.item())
        self.optim.zero_grad()  # reset gradients before backprop
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.agent.parameters(), 0.5)
        self.optim.step()

    def get_traj_loss(self, values, dones):
        # Not needed for CEM
        pass

    def collect_rollouts(self, state: torch.Tensor):
        """Function to collect rollouts

        Collects rollouts by playing out the current policy in the environment
        and adding the transitions to the rollout buffer.

        Args:
            state (:obj:`torch.Tensor`): The starting state of the environment

        Returns:
            values (:obj:`torch.Tensor`): Values of states encountered during the rollout
            dones (:obj:`torch.Tensor`): Game over statuses of each environment
        """
        for i in range(self.rollout_size):
            action = self.select_action(state)

            next_state, reward, dones, _ = self.env.step(action)

            if self.render:
                self.env.render()

            self.rollout.add(
                state,
                action.reshape(self.env.n_envs, 1),
                reward,
                dones,
                torch.tensor(0),
                torch.tensor(0),
            )

            state = next_state

            self.collect_rewards(dones, i)

            if dones:
                break

        return torch.tensor(0), dones

    def collect_rewards(self, dones: torch.Tensor, timestep: int):
        """Helper function to collect rewards

        Runs through all the envs and collects rewards accumulated during rollouts

        Args:
            dones (:obj:`torch.Tensor`): Game over statuses of each environment
            timestep (int): Timestep during rollout
        """
        for i, done in enumerate(dones):
            if done or timestep == self.rollout_size - 1:
                self.rewards.append(self.env.episode_reward[i].detach().clone())
                self.env.reset_single_env(i)

    def get_hyperparams(self):
        # return self.agent.get_hyperparams()
        pass

    def get_logging_params(self):
        logs = {
            "crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]),
            "mean_reward": safe_mean(self.rewards),
        }
        return logs

    def empty_logs(self):
        # self.agent.empty_logs()
        self.logs = {}
        self.logs["crossentropy_loss"] = []
        self.rewards = []
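For reference, a minimal sketch of driving the agent directly, without a trainer; the environment name and iteration count are illustrative:

    from genrl.agents.modelbased.cem.cem import CEM
    from genrl.environments import VectorEnv

    env = VectorEnv("CartPole-v0", 1)
    agent = CEM(env, percentile=70)
    for _ in range(10):
        agent.update_params()  # one CEM iteration: plan, select elites, fit policy
        print(agent.get_logging_params())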
12 changes: 12 additions & 0 deletions tests/test_deep/test_agents/test_cem.py
@@ -0,0 +1,12 @@
import gym

from genrl.agents.modelbased.cem.cem import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer


def test_CEM():
    env = VectorEnv("CartPole-v0", 1)
    algo = CEM(env, percentile=70, policy_layers=[100])
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
