Modular

The core principle of cherry is to organise an RL framework around disjoint, portable modules. These portable modules reflect the Agent/Environment decoupling presented in Spinning Up. The Agent is composed with its function-approximation (deep neural network) representation. You can check out our Agents and Environments.

Agent

The Agent API is structured as follows

class AgentName():

  def __init__(self, cfgs, model=None, model_file=None,
               device=None, log_level='info'):

  @property
  def state(self):

  @state.setter
  def state(self, state):

  @state.getter
  def state(self):

  @property
  def reward(self):

  @reward.setter
  def reward(self, reward):

  @reward.getter
  def reward(self):

  @property
  def ep_reward(self):

  @ep_reward.setter
  def ep_reward(self, ep_reward):

  @ep_reward.getter
  def ep_reward(self):

  def reset(self):

  def flash(self):

  def init(self, model_file):

  def augment_state(self):

  def append_state(self, state):

  def action(self, state, deterministic=False):

  def train(self, env, train_cfgs, gitsha, model_dest):

  def play(self, env, test_cfgs, gitsha):

  def optimize(self):
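As a concrete illustration of the state/reward bookkeeping in this skeleton, here is one plausible way the properties might be backed by a fixed-length frame buffer. This is a sketch, not cherry's implementation: the class name, buffer length, frame shape and the semantics of reset/flash are all assumptions.

import numpy as np
from collections import deque

class ExampleAgent:

  def __init__(self, state_len=4, state_shape=(84, 84)):
    self.state_shape = state_shape
    self._state = deque(maxlen=state_len)   # rolling window of the most recent frames
    self._reward = 0.0                      # reward collected at the current step
    self._ep_reward = 0.0                   # cumulative reward over the episode
    self.reset()

  @property
  def state(self):
    # stack the buffered frames into a single network input
    return np.stack(self._state, axis=0)

  def append_state(self, state):
    self._state.append(state)

  def reset(self):
    # fresh episode: zero-fill the frame buffer and clear the episode reward
    self._state.extend(np.zeros(self.state_shape) for _ in range(self._state.maxlen))
    self._ep_reward = 0.0

  def flash(self):
    # clear per-step quantities between transitions
    self._reward = 0.0

agent = ExampleAgent()
agent.append_state(np.ones((84, 84)))
print(agent.state.shape)   # (4, 84, 84)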

Architecture

The Architecture API is structured as follows

class ModelName(torch.nn.Module):

  def __init__(self, input_shape, state_size, action_size, device):

    super(ModelName, self).__init__()

    self.device = device
    self.state_size = state_size
    self.action_size = action_size
    self.input_shape = input_shape

  def init_weights(self, m):

  def forward(self, x):
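For illustration, a minimal fully connected model following this skeleton could look like the sketch below. The layer sizes, the Xavier initialisation scheme and treating input_shape as a flat feature size are assumptions, not cherry's defaults.

import torch
import torch.nn as nn

class ExampleModel(nn.Module):

  def __init__(self, input_shape, state_size, action_size, device):

    super(ExampleModel, self).__init__()

    self.device = device
    self.state_size = state_size
    self.action_size = action_size
    self.input_shape = input_shape

    # treat input_shape as a flat feature size for simplicity
    self.fc = nn.Sequential(
        nn.Linear(input_shape, 64),
        nn.ReLU(),
        nn.Linear(64, action_size),
    )
    self.apply(self.init_weights)
    self.to(self.device)

  def init_weights(self, m):
    # Xavier-initialise linear layers, zero the biases (an assumed scheme)
    if isinstance(m, nn.Linear):
      nn.init.xavier_uniform_(m.weight)
      nn.init.zeros_(m.bias)

  def forward(self, x):
    return self.fc(x)

model = ExampleModel(input_shape=4, state_size=4, action_size=2, device='cpu')
q_values = model(torch.zeros(1, 4))   # tensor of shape (1, 2)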

Environment

The Environment API is structured as follows

class EnvironmentName():

  def __init__(self, cfgs, play=False):

    self.env = None
    self.seed = cfgs.get('seed')
    self.env_name = cfgs.get('name')
    self.env_solution = cfgs.get('env_solution')

  def reset(self):

  def step(self, action):

  def close(self):

  def update_env(self, update_fn, **kwargs):
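A concrete environment module following this skeleton might wrap a Gym environment as sketched below. This is illustrative only: it assumes the classic gym API (where env.seed() still exists) and a hypothetical class name and sample config.

import gym

class ExampleEnvironment:

  def __init__(self, cfgs, play=False):

    self.seed = cfgs.get('seed')
    self.env_name = cfgs.get('name')
    self.env_solution = cfgs.get('env_solution')

    self.env = gym.make(self.env_name)
    self.env.seed(self.seed)

  def reset(self):
    return self.env.reset()

  def step(self, action):
    return self.env.step(action)

  def close(self):
    self.env.close()

  def update_env(self, update_fn, **kwargs):
    # swap or wrap the underlying env, e.g. with an observation wrapper
    self.env = update_fn(self.env, **kwargs)

cfgs = {'seed': 42, 'name': 'CartPole-v0', 'env_solution': 195}
env = ExampleEnvironment(cfgs)
state = env.reset()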

Reproducibility

Pseudorandom number generation

Taking a cue from PyTorch (and for the sake of our sanity), we set the random seeds from the YAML config files as follows.

  env = make_atari(self.env_name)
  self.env = wrap_deepmind(env)
  self.env.seed(self.seed)
  torch.manual_seed(self.seed)

If you are using NumPy's random number generator in your code, please set the same random seed, following the PyTorch convention.
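For example, alongside the seeding block above (a sketch; self.seed is the same value read from the config):

  import numpy as np
  np.random.seed(self.seed)   # same seed value as torch.manual_seed above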

Tracking code & models

Training/testing is set up through YAML config files. Sample files are located here. During training, the config file & model weights are saved at model_dest and appended with the commit gitsha of the repo for traceability & reproducibility. We follow PyTorch Hub's convention of using the first 8 characters of the commit gitsha for tagging files & models. For example:

# commit gitsha
git log -n 1
commit c099dff57dac5dbd4bccce7cfb5f8e98c145ee2d (tag: v1.0)
Author: Freja <[email protected]>
Date:   Tue Jan 28 14:26:03 2020 +0100

    readme update
# <model_dest> from configs/control-ddpg.yaml
ls <model_dest>
-rw-rw-r-- 1 moabit 71425 Jan 27 16:14 agent-000030000-c099df.pth
-rw-rw-r-- 1 moabit 71425 Jan 27 16:15 agent-final-c099df.pth
-rw-rw-r-- 1 moabit  1674 Jan 27 16:12 control-ddpg-c099df.yaml
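One possible way to pick up the gitsha programmatically for tagging is sketched below; this is an assumption, not necessarily cherry's implementation, and it presumes the code runs from inside the repository.

import subprocess

# full gitsha of the current HEAD
gitsha = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()
tag = gitsha[:8]   # first 8 characters, per the tagging convention above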