
Commit

Docstrings
hades-rp2010 committed Oct 17, 2020
1 parent f86b046 commit f5a189d
Showing 5 changed files with 133 additions and 71 deletions.
28 changes: 5 additions & 23 deletions genrl/agents/modelbased/base.py
@@ -2,6 +2,8 @@

import torch

from genrl.agents import BaseAgent


class Planner:
def __init__(self, initial_state, dynamics_model=None):
@@ -19,12 +21,10 @@ def execute_actions(self):
raise NotImplementedError


class ModelBasedAgent(ABC):
def __init__(self, env, planner=None, render=False, device="cpu"):
self.env = env
class ModelBasedAgent(BaseAgent):
def __init__(self, *args, planner=None, **kwargs):
super(ModelBasedAgent, self).__init__(*args, **kwargs)
self.planner = planner
self.render = render
self.device = torch.device(device)

def plan(self):
"""
@@ -45,21 +45,3 @@ def value_equivalence(self, state_space):
To be used for approximate value estimation methods e.g. Value Iteration Networks
"""
raise NotImplementedError

def update_params(self):
"""
Update the parameters (Parameters of the learnt model and/or Parameters of the policy being used)
"""
raise NotImplementedError

def get_hyperparans(self):
raise NotImplementedError

def get_logging_params(self):
raise NotImplementedError

def _load_weights(self, weights):
raise NotImplementedError

def empty_logs(self):
raise NotImplementedError
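
The change above folds env/render/device handling into BaseAgent and leaves ModelBasedAgent with only the planner. A minimal standalone sketch of that constructor pattern (the class and argument names below are illustrative, not genrl's actual BaseAgent API):

import torch


class Base:
    # Common bookkeeping lives in the base class
    def __init__(self, env, render=False, device="cpu"):
        self.env = env
        self.render = render
        self.device = torch.device(device)


class ModelBased(Base):
    # Subclass only adds what is specific to it and forwards the rest
    def __init__(self, *args, planner=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.planner = planner


agent = ModelBased("CartPole-v0", planner=None, render=False, device="cpu")
print(agent.env, agent.device, agent.planner)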
142 changes: 104 additions & 38 deletions genrl/agents/modelbased/cem/cem.py
@@ -1,3 +1,5 @@
from typing import Any, Dict

import numpy as np
import torch
import torch.nn.functional as F
@@ -8,28 +10,46 @@


class CEM(ModelBasedAgent):
"""Cross Entropy method algorithm (CEM)
Attributes:
network (str): The type of network to be used
env (Environment): The environment the agent is supposed to act on
create_model (bool): Whether the model of the algo should be created when initialised
policy_layers (:obj:`tuple` of :obj:`int`): Layers in the Neural Network of the policy
lr_policy (float): learning rate of the policy
percentile (float): Top percentile of rewards to consider as elite
simulations_per_epoch (int): Number of simulations to perform before taking a gradient step
rollout_size (int): Capacity of the rollout buffer
render (bool): Whether to render the environment or not
device (str): Hardware being used for training. Options:
["cuda" -> GPU, "cpu" -> CPU]
"""

def __init__(
self,
*args,
network: str = "mlp",
policy_layers: tuple = (100,),
lr_policy=1e-3,
percentile: int = 70,
percentile: float = 70,
simulations_per_epoch: int = 1000,
rollout_size,
**kwargs
):
super(CEM, self).__init__(*args, **kwargs)
self.network = network
self.rollout_size = rollout_size
self.rollout = RolloutBuffer(self.rollout_size, self.env)
self.policy_layers = policy_layers
self.lr_policy = lr_policy
self.percentile = percentile
self.simulations_per_epoch = simulations_per_epoch

self._create_model()
self.empty_logs()

def _create_model(self):
"""Function to initialize the Policy
This will create the Policy net for the CEM agent
"""
self.state_dim, self.action_dim, discrete, action_lim = get_env_properties(
self.env, self.network
)
@@ -44,35 +64,74 @@ def _create_model(self):
self.optim = torch.optim.Adam(self.agent.parameters(), lr=self.lr_policy)
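
_create_model asks get_env_properties for the input/output sizes and builds an MlpPolicy plus an Adam optimizer. Roughly, for a discrete CartPole-style environment, the equivalent plain-PyTorch setup would look like the sketch below (the layer sizes and the logits head are assumptions for illustration, not genrl's MlpPolicy internals):

import torch
import torch.nn as nn

state_dim, action_dim = 4, 2   # e.g. CartPole-style observation/action sizes
policy_layers = (100,)         # mirrors the policy_layers default above
lr_policy = 1e-3

layers, in_dim = [], state_dim
for hidden in policy_layers:
    layers += [nn.Linear(in_dim, hidden), nn.ReLU()]
    in_dim = hidden
layers.append(nn.Linear(in_dim, action_dim))  # logits over discrete actions

policy = nn.Sequential(*layers)
optim = torch.optim.Adam(policy.parameters(), lr=lr_policy)
print(policy)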

def plan(self):
"""Function to plan out one episode
Returns:
states (:obj:`list` of :obj:`torch.Tensor`): Batch of states the agent encountered in the episode
actions (:obj:`list` of :obj:`torch.Tensor`): Batch of actions the agent took in the episode
rewards (:obj:`torch.Tensor`): The episode reward obtained
"""
state = self.env.reset()
self.rollout.reset()
states, actions = self.collect_rollouts(state)
return (states, actions, self.rewards[-1])
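
plan() produces one (states, actions, episode_reward) tuple per simulated episode; update_params later unpacks a batch of these with zip(*sess). A toy illustration of that unpacking, with made-up episode data:

# Three fake "episodes": (states, actions, episode_reward)
sess = [
    (["s0", "s1"], ["a0", "a1"], 7.0),
    (["s0"], ["a0"], 3.0),
    (["s0", "s1", "s2"], ["a0", "a1", "a2"], 9.0),
]
batch_states, batch_actions, batch_rewards = zip(*sess)
print(batch_rewards)  # (7.0, 3.0, 9.0)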

def select_elites(self, states_batch, actions_batch, rewards_batch):
"""Function to select the elite states and elite actions based on the episode reward
Args:
states_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of states
actions_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of actions
rewards_batch (:obj:`list` of :obj:`torch.Tensor`): Batch of rewards
Returns:
elite_states (:obj:`torch.Tensor`): Elite batch of states based on episode reward
elite_actions (:obj:`torch.Tensor`): Actions the agent took during the elite batch of states
"""
reward_threshold = np.percentile(rewards_batch, self.percentile)
elite_states = [
s.unsqueeze(0).clone()
for i in range(len(states_batch))
if rewards_batch[i] >= reward_threshold
for s in states_batch[i]
]
elite_actions = [
a.unsqueeze(0).clone()
for i in range(len(actions_batch))
if rewards_batch[i] >= reward_threshold
for a in actions_batch[i]
]

return torch.cat(elite_states, dim=0), torch.cat(elite_actions, dim=0)
elite_states = torch.cat(
[
s.unsqueeze(0).clone()
for i in range(len(states_batch))
if rewards_batch[i] >= reward_threshold
for s in states_batch[i]
],
dim=0,
)
elite_actions = torch.cat(
[
a.unsqueeze(0).clone()
for i in range(len(actions_batch))
if rewards_batch[i] >= reward_threshold
for a in actions_batch[i]
],
dim=0,
)

return elite_states, elite_actions
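
select_elites keeps only the transitions from episodes whose return reaches the chosen percentile and flattens them into a single tensor per field. A self-contained toy run of the same logic (episode count, shapes, and rewards are made up):

import numpy as np
import torch

percentile = 70
# Four toy episodes, two steps each: 3-dim states, scalar discrete actions
states_batch = [torch.randn(2, 3) for _ in range(4)]
actions_batch = [torch.randint(0, 2, (2,)) for _ in range(4)]
rewards_batch = [5.0, 8.0, 25.0, 25.0]

reward_threshold = np.percentile(rewards_batch, percentile)  # 25.0 here
elite_states = torch.cat(
    [
        s.unsqueeze(0).clone()
        for i in range(len(states_batch))
        if rewards_batch[i] >= reward_threshold
        for s in states_batch[i]
    ],
    dim=0,
)
elite_actions = torch.cat(
    [
        a.unsqueeze(0).clone()
        for i in range(len(actions_batch))
        if rewards_batch[i] >= reward_threshold
        for a in actions_batch[i]
    ],
    dim=0,
)
print(reward_threshold, elite_states.shape, elite_actions.shape)
# Only the two 25.0-return episodes survive: shapes (4, 3) and (4,)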

def select_action(self, state):
"""Select action given state
Action selection policy for the Cross Entropy agent
Args:
state (:obj:`torch.Tensor`): Current state of the agent
Returns:
action (:obj:`torch.Tensor`): Action taken by the agent
"""
state = torch.as_tensor(state).float()
action, dist = self.agent.get_action(state)
return action, torch.zeros((1, self.env.n_envs)), dist.log_prob(action).cpu()
return action
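
select_action now returns just the sampled action; the value and log-probability placeholders were dropped along with the rollout-buffer bookkeeping. Assuming the policy head emits logits over discrete actions, the sampling step looks roughly like this plain-PyTorch sketch (the two-layer policy is a stand-in, not genrl's MlpPolicy):

import torch
import torch.nn as nn
from torch.distributions import Categorical

policy = nn.Sequential(nn.Linear(4, 100), nn.ReLU(), nn.Linear(100, 2))

state = torch.as_tensor([0.02, -0.01, 0.03, 0.04]).float()
logits = policy(state)
dist = Categorical(logits=logits)   # distribution over discrete actions
action = dist.sample()
print(action.item(), dist.log_prob(action).item())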

def update_params(self):
sess = [self.plan() for _ in range(100)]
"""Updates the the Policy network of the CEM agent
Function to update the policy network
"""
sess = [self.plan() for _ in range(self.simulations_per_epoch)]
batch_states, batch_actions, batch_rewards = zip(*sess)
elite_states, elite_actions = self.select_elites(
batch_states, batch_actions, batch_rewards
@@ -101,13 +160,13 @@ def collect_rollouts(self, state: torch.Tensor):
state (:obj:`torch.Tensor`): The starting state of the environment
Returns:
values (:obj:`torch.Tensor`): Values of states encountered during the rollout
dones (:obj:`torch.Tensor`): Game over statuses of each environment
states (:obj:`list`): list of states the agent encountered during the episode
actions (:obj:`list`): list of actions the agent took in the corresponding states
"""
states = []
actions = []
for i in range(self.rollout_size):
action, value, log_probs = self.select_action(state)
action = self.select_action(state)
states.append(state)
actions.append(action)

@@ -116,20 +175,11 @@ def collect_rollouts(self, state: torch.Tensor):
if self.render:
self.env.render()

self.rollout.add(
state,
action.reshape(self.env.n_envs, 1),
reward,
dones,
value,
log_probs.detach(),
)

state = next_state

self.collect_rewards(dones, i)

if dones:
if torch.any(dones.byte()):
break

return states, actions
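
collect_rollouts simply steps the environment with the current policy, recording states and actions until any sub-environment reports done or rollout_size steps elapse. A stripped-down, single-environment sketch with a stub environment (ToyEnv and random_policy are placeholders for genrl's vectorized env and the CEM policy):

import torch


class ToyEnv:
    """Tiny stand-in for a vectorized env: random 4-dim states, done after 5 steps."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return torch.randn(4)

    def step(self, action):
        self.t += 1
        next_state = torch.randn(4)
        reward = 1.0
        done = torch.tensor([self.t >= 5])
        return next_state, reward, done, {}


def random_policy(state):
    return torch.randint(0, 2, (1,))


env, rollout_size = ToyEnv(), 100
state, states, actions = env.reset(), [], []
for i in range(rollout_size):
    action = random_policy(state)
    states.append(state)
    actions.append(action)
    state, reward, dones, _ = env.step(action)
    if torch.any(dones.byte()):   # stop as soon as any sub-env is done
        break
print(len(states), len(actions))  # 5 and 5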
@@ -146,24 +196,40 @@ def collect_rewards(self, dones: torch.Tensor, timestep: int):
for i, done in enumerate(dones):
if done or timestep == self.rollout_size - 1:
self.rewards.append(self.env.episode_reward[i].detach().clone())
# self.env.reset_single_env(i)

def get_hyperparams(self):
def get_hyperparams(self) -> Dict[str, Any]:
"""Get relevant hyperparameters to save
Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"network": self.network,
"lr_policy": self.lr_policy,
"rollout_size": self.rollout_size,
}
return hyperparams
return hyperparams, self.agent.state_dict()

def _load_weights(self, weights) -> None:
self.agent.load_state_dict(weights)
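
get_hyperparams now returns the hyperparameter dict together with the policy's state_dict, and _load_weights restores the latter. A hedged sketch of how a trainer might round-trip the pair with torch.save/torch.load (the file name and the stand-in policy are placeholders):

import torch
import torch.nn as nn

# Placeholder policy standing in for self.agent
policy = nn.Linear(4, 2)

hyperparams = {"network": "mlp", "lr_policy": 1e-3, "rollout_size": 100}
checkpoint = {"hyperparams": hyperparams, "weights": policy.state_dict()}
torch.save(checkpoint, "cem_checkpoint.pt")

restored = torch.load("cem_checkpoint.pt")
policy.load_state_dict(restored["weights"])   # mirrors _load_weights
print(restored["hyperparams"])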

def get_logging_params(self):
def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Returns:
logs (:obj:`dict`): Logging parameters for monitoring training
"""
logs = {
"crossentropy_loss": safe_mean(self.logs["crossentropy_loss"]),
"mean_reward": safe_mean(self.rewards),
}

self.empty_logs()
return logs

def empty_logs(self):
"""Empties logs"""
self.logs = {}
self.logs["crossentropy_loss"] = []
self.rewards = []
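
get_logging_params averages the accumulated losses and rewards with safe_mean, which is imported from genrl's utilities elsewhere in this file. Assuming it is simply a mean that tolerates empty lists, a minimal version might look like:

import numpy as np


def safe_mean(values):
    """Mean of a list of numbers, or 0 if the list is empty."""
    return np.mean(values) if len(values) > 0 else 0.0


print(safe_mean([1.0, 2.0, 3.0]), safe_mean([]))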
1 change: 1 addition & 0 deletions tests/test_agents/test_modelbased/__init__.py
@@ -0,0 +1 @@
from tests.test_agents.test_modelbased.test_cem import TestCEM
23 changes: 23 additions & 0 deletions tests/test_agents/test_modelbased/test_cem.py
@@ -0,0 +1,23 @@
import shutil

from genrl.agents import CEM
from genrl.environments import VectorEnv
from genrl.trainers import OnPolicyTrainer


class TestCEM:
def test_CEM(self):
env = VectorEnv("CartPole-v0", 1)
algo = CEM(
"mlp",
env,
percentile=70,
policy_layers=[100],
rollout_size=100,
simulations_per_epoch=100,
)
trainer = OnPolicyTrainer(
algo, env, log_mode=["csv"], logdir="./logs", epochs=1
)
trainer.train()
shutil.rmtree("./logs")
10 changes: 0 additions & 10 deletions tests/test_deep/test_agents/test_cem.py

This file was deleted.
