Skip to content

Commit

Permalink
Add docstrings to core library components (#273)
Browse files Browse the repository at this point in the history
Documentation style https://google.github.io/styleguide/pyguide.html

Co-authored-by: anishdiwan <[email protected]>
  • Loading branch information
anishhdiwan and anishdiwan authored Jan 31, 2024
1 parent 165652c commit cba782c
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 0 deletions.
20 changes: 20 additions & 0 deletions rl_games/algos_torch/a2c_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,20 @@


class A2CAgent(a2c_common.ContinuousA2CBase):
"""Continuous PPO Agent
The A2CAgent class inerits from the continuous asymmetric actor-critic class and makes modifications for PPO.
"""
def __init__(self, base_name, params):
"""Initialise the algorithm with passed params
Args:
base_name (:obj:`str`): Name passed on to the observer and used for checkpoints etc.
params (:obj `dict`): Algorithm parameters
"""

a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
obs_shape = self.obs_shape
build_config = {
Expand Down Expand Up @@ -75,6 +87,14 @@ def get_masked_action_values(self, obs, action_masks):
assert False

def calc_gradients(self, input_dict):
"""Compute gradients needed to step the networks of the algorithm.
Core algo logic is defined here
Args:
input_dict (:obj:`dict`): Algo inputs as a dict.
"""
value_preds_batch = input_dict['old_values']
old_action_log_probs_batch = input_dict['old_logp_actions']
advantage = input_dict['advantages']
Expand Down
19 changes: 19 additions & 0 deletions rl_games/algos_torch/a2c_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,19 @@


class DiscreteA2CAgent(a2c_common.DiscreteA2CBase):
"""Discrete PPO Agent
The DiscreteA2CAgent class inerits from the discrete asymmetric actor-critic class and makes modifications for PPO.
"""
def __init__(self, base_name, params):
"""Initialise the algorithm with passed params
Args:
base_name (:obj:`str`): Name passed on to the observer and used for checkpoints etc.
params (:obj `dict`): Algorithm parameters
"""
a2c_common.DiscreteA2CBase.__init__(self, base_name, params)
obs_shape = self.obs_shape

Expand Down Expand Up @@ -108,6 +119,14 @@ def train_actor_critic(self, input_dict):
return self.train_result

def calc_gradients(self, input_dict):
"""Compute gradients needed to step the networks of the algorithm.
Core algo logic is defined here
Args:
input_dict (:obj:`dict`): Algo inputs as a dict.
"""
value_preds_batch = input_dict['old_values']
old_action_log_probs_batch = input_dict['old_logp_actions']
advantage = input_dict['advantages']
Expand Down
8 changes: 8 additions & 0 deletions rl_games/common/env_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def create_env(name, **kwargs):
env = wrappers.TimeLimit(env, steps_limit)
return env

# Dictionary of env_name as key and a sub-dict containing env_type and a env-creator function
configurations = {
'CartPole-v1' : {
'vecenv_type' : 'RAY',
Expand Down Expand Up @@ -458,4 +459,11 @@ def get_obs_and_action_spaces_from_config(config):


def register(name, config):
"""Add a new key-value pair to the known environments (configurations dict).
Args:
name (:obj:`str`): Name of the env to be added.
config (:obj:`dict`): Dictionary with env type and a creator function.
"""
configurations[name] = config
25 changes: 25 additions & 0 deletions rl_games/common/object_factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,39 @@
class ObjectFactory:
"""General-purpose class to instantiate some other base class from rl_games. Usual use case it to instantiate algos, players etc.
The ObjectFactory class is used to dynamically create any other object using a builder function (typically a lambda function).
"""

def __init__(self):
"""Initialise a dictionary of builders with keys as `str` and values as functions.
"""
self._builders = {}

def register_builder(self, name, builder):
"""Register a passed builder by adding to the builders dict.
Initialises runners and players for all algorithms available in the library using `rl_games.common.object_factory.ObjectFactory`
Args:
name (:obj:`str`): Key of the added builder.
builder (:obj `func`): Function to return the requested object
"""
self._builders[name] = builder

def set_builders(self, builders):
self._builders = builders

def create(self, name, **kwargs):
"""Create the requested object by calling a registered builder function.
Args:
name (:obj:`str`): Key of the requested builder.
**kwargs: Arbitrary kwargs needed for the builder function
"""
builder = self._builders.get(name)
if not builder:
raise ValueError(name)
Expand Down
47 changes: 47 additions & 0 deletions rl_games/common/vecenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@
import torch

class RayWorker:
"""Wrapper around a third-party (gym for example) environment class that enables parallel training.
The RayWorker class wraps around another environment class to enable the use of this
environment within an asynchronous parallel training setup
"""
def __init__(self, config_name, config):
"""Initialise the class. Sets up the environment creator using the `rl_games.common.env_configurations.configuraitons` dict
Args:
config_name (:obj:`str`): Key of the environment to create.
config: Misc. kwargs passed on to the environment creator function
"""
self.env = configurations[config_name]['env_creator'](**config)

def _obs_to_fp32(self, obs):
Expand All @@ -27,6 +40,12 @@ def _obs_to_fp32(self, obs):
return obs

def step(self, action):
"""Step the environment and reset if done
Args:
action (type depends on env): Action to take.
"""
next_state, reward, is_done, info = self.env.step(action)

if np.isscalar(is_done):
Expand Down Expand Up @@ -95,9 +114,23 @@ def get_env_info(self):


class RayVecEnv(IVecEnv):
"""Main env class that manages several `rl_games.common.vecenv.Rayworker` objects for parallel training
The RayVecEnv class manages a set of individual environments and wraps around the methods from RayWorker.
Each worker is executed asynchronously.
"""
import ray

def __init__(self, config_name, num_actors, **kwargs):
"""Initialise the class. Sets up the config for the environment and creates individual workers to manage.
Args:
config_name (:obj:`str`): Key of the environment to create.
num_actors (:obj:`int`): Number of environments (actors) to create
**kwargs: Misc. kwargs passed on to the environment creator function within the RayWorker __init__
"""
self.config_name = config_name
self.num_actors = num_actors
self.use_torch = False
Expand Down Expand Up @@ -131,6 +164,14 @@ def __init__(self, config_name, num_actors, **kwargs):
self.concat_func = np.concatenate

def step(self, actions):
"""Step all individual environments (using the created workers).
Returns a concatenated array of observations, rewards, done states, and infos if the env allows concatenation.
Else returns a nested dict.
Args:
action (type depends on env): Action to take.
"""
newobs, newstates, newrewards, newdones, newinfos = [], [], [], [], []
res_obs = []
if self.num_agents == 1:
Expand Down Expand Up @@ -218,6 +259,12 @@ def reset(self):
vecenv_config = {}

def register(config_name, func):
"""Add an environment type (for example RayVecEnv) to the list of available types `rl_games.common.vecenv.vecenv_config`
Args:
config_name (:obj:`str`): Key of the environment to create.
func (:obj:`func`): Function that creates the environment
"""
vecenv_config[config_name] = func

def create_vec_env(config_name, num_actors, **kwargs):
Expand Down
41 changes: 41 additions & 0 deletions rl_games/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,24 @@ def _override_sigma(agent, args):


class Runner:
"""Runs training/inference (playing) procedures as per a given configuration.
The Runner class provides a high-level API for instantiating agents for either training or playing
with an RL algorithm. It further logs training metrics.
"""

def __init__(self, algo_observer=None):
"""Initialise the runner instance with algorithms and observers.
Initialises runners and players for all algorithms available in the library using `rl_games.common.object_factory.ObjectFactory`
Args:
algo_observer (:obj:`rl_games.common.algo_observer.AlgoObserver`, optional): Algorithm observer that logs training metrics.
Defaults to `rl_games.common.algo_observer.DefaultAlgoObserver`
"""

self.algo_factory = object_factory.ObjectFactory()
self.algo_factory.register_builder('a2c_continuous', lambda **kwargs : a2c_continuous.A2CAgent(**kwargs))
self.algo_factory.register_builder('a2c_discrete', lambda **kwargs : a2c_discrete.DiscreteA2CAgent(**kwargs))
Expand All @@ -55,6 +71,13 @@ def reset(self):
pass

def load_config(self, params):
"""Loads passed config params.
Args:
params (:obj:`dict`): Parameters passed in as a dict obtained from a yaml file or some other config format.
"""

self.seed = params.get('seed', None)
if self.seed is None:
self.seed = int(time.time())
Expand Down Expand Up @@ -109,13 +132,25 @@ def load(self, yaml_config):
self.load_config(params=self.default_config)

def run_train(self, args):
"""Run the training procedure from the algorithm passed in.
Args:
args (:obj:`dict`): Train specific args passed in as a dict obtained from a yaml file or some other config format.
"""
print('Started to train')
agent = self.algo_factory.create(self.algo_name, base_name='run', params=self.params)
_restore(agent, args)
_override_sigma(agent, args)
agent.train()

def run_play(self, args):
"""Run the inference procedure from the algorithm passed in.
Args:
args (:obj:`dict`): Playing specific args passed in as a dict obtained from a yaml file or some other config format.
"""
print('Started to play')
player = self.create_player()
_restore(player, args)
Expand All @@ -129,6 +164,12 @@ def reset(self):
pass

def run(self, args):
"""Run either train/play depending on the args.
Args:
args (:obj:`dict`): Args passed in as a dict obtained from a yaml file or some other config format.
"""
if args['train']:
self.run_train(args)
elif args['play']:
Expand Down

0 comments on commit cba782c

Please sign in to comment.