diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py
index b731a4ed..78795a2b 100644
--- a/rl_games/algos_torch/a2c_continuous.py
+++ b/rl_games/algos_torch/a2c_continuous.py
@@ -10,8 +10,20 @@
 
 class A2CAgent(a2c_common.ContinuousA2CBase):
+    """Continuous PPO Agent.
+
+    The A2CAgent class inherits from the continuous asymmetric actor-critic class and makes modifications for PPO.
+
+    """
     def __init__(self, base_name, params):
+        """Initialise the algorithm with the passed params.
+
+        Args:
+            base_name (:obj:`str`): Name passed on to the observer and used for checkpoints etc.
+            params (:obj:`dict`): Algorithm parameters.
+
+        """
+
         a2c_common.ContinuousA2CBase.__init__(self, base_name, params)
         obs_shape = self.obs_shape
         build_config = {
@@ -75,6 +87,14 @@ def get_masked_action_values(self, obs, action_masks):
         assert False
 
     def calc_gradients(self, input_dict):
+        """Compute gradients needed to step the networks of the algorithm.
+
+        The core algorithm logic is defined here.
+
+        Args:
+            input_dict (:obj:`dict`): Algo inputs as a dict.
+
+        """
         value_preds_batch = input_dict['old_values']
         old_action_log_probs_batch = input_dict['old_logp_actions']
         advantage = input_dict['advantages']
diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py
index 467d8d86..312bd17c 100644
--- a/rl_games/algos_torch/a2c_discrete.py
+++ b/rl_games/algos_torch/a2c_discrete.py
@@ -13,8 +13,19 @@
 
 class DiscreteA2CAgent(a2c_common.DiscreteA2CBase):
+    """Discrete PPO Agent.
+
+    The DiscreteA2CAgent class inherits from the discrete asymmetric actor-critic class and makes modifications for PPO.
+
+    """
     def __init__(self, base_name, params):
+        """Initialise the algorithm with the passed params.
+
+        Args:
+            base_name (:obj:`str`): Name passed on to the observer and used for checkpoints etc.
+            params (:obj:`dict`): Algorithm parameters.
+
+        """
         a2c_common.DiscreteA2CBase.__init__(self, base_name, params)
         obs_shape = self.obs_shape
@@ -108,6 +119,14 @@ def train_actor_critic(self, input_dict):
         return self.train_result
 
     def calc_gradients(self, input_dict):
+        """Compute gradients needed to step the networks of the algorithm.
+
+        The core algorithm logic is defined here.
+
+        Args:
+            input_dict (:obj:`dict`): Algo inputs as a dict.
+
+        """
         value_preds_batch = input_dict['old_values']
         old_action_log_probs_batch = input_dict['old_logp_actions']
         advantage = input_dict['advantages']
diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py
index d8b335e3..43c8ebe1 100644
--- a/rl_games/common/env_configurations.py
+++ b/rl_games/common/env_configurations.py
@@ -253,6 +253,7 @@ def create_env(name, **kwargs):
         env = wrappers.TimeLimit(env, steps_limit)
     return env
 
+# Dictionary of env_name as key and a sub-dict containing env_type and an env-creator function
 configurations = {
     'CartPole-v1' : {
         'vecenv_type' : 'RAY',
@@ -458,4 +459,11 @@ def get_obs_and_action_spaces_from_config(config):
 
 
 def register(name, config):
+    """Add a new key-value pair to the known environments (the configurations dict).
+
+    Args:
+        name (:obj:`str`): Name of the env to be added.
+        config (:obj:`dict`): Dictionary with env type and a creator function.
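+
+    Example (illustrative only; the env name and creator lambda are hypothetical, and `gym` is assumed to be importable in this module):
+        >>> register('MyEnv-v0', {
+        ...     'vecenv_type': 'RAY',
+        ...     'env_creator': lambda **kwargs: gym.make('MyEnv-v0'),
+        ... })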
+ + """ configurations[name] = config \ No newline at end of file diff --git a/rl_games/common/object_factory.py b/rl_games/common/object_factory.py index 4cd97f98..45728098 100644 --- a/rl_games/common/object_factory.py +++ b/rl_games/common/object_factory.py @@ -1,14 +1,39 @@ class ObjectFactory: + """General-purpose class to instantiate some other base class from rl_games. Usual use case it to instantiate algos, players etc. + + The ObjectFactory class is used to dynamically create any other object using a builder function (typically a lambda function). + + """ + def __init__(self): + """Initialise a dictionary of builders with keys as `str` and values as functions. + + """ self._builders = {} def register_builder(self, name, builder): + """Register a passed builder by adding to the builders dict. + + Initialises runners and players for all algorithms available in the library using `rl_games.common.object_factory.ObjectFactory` + + Args: + name (:obj:`str`): Key of the added builder. + builder (:obj `func`): Function to return the requested object + + """ self._builders[name] = builder def set_builders(self, builders): self._builders = builders def create(self, name, **kwargs): + """Create the requested object by calling a registered builder function. + + Args: + name (:obj:`str`): Key of the requested builder. + **kwargs: Arbitrary kwargs needed for the builder function + + """ builder = self._builders.get(name) if not builder: raise ValueError(name) diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py index 01016723..c29fd4be 100644 --- a/rl_games/common/vecenv.py +++ b/rl_games/common/vecenv.py @@ -8,7 +8,20 @@ import torch class RayWorker: + """Wrapper around a third-party (gym for example) environment class that enables parallel training. + + The RayWorker class wraps around another environment class to enable the use of this + environment within an asynchronous parallel training setup + + """ def __init__(self, config_name, config): + """Initialise the class. Sets up the environment creator using the `rl_games.common.env_configurations.configuraitons` dict + + Args: + config_name (:obj:`str`): Key of the environment to create. + config: Misc. kwargs passed on to the environment creator function + + """ self.env = configurations[config_name]['env_creator'](**config) def _obs_to_fp32(self, obs): @@ -27,6 +40,12 @@ def _obs_to_fp32(self, obs): return obs def step(self, action): + """Step the environment and reset if done + + Args: + action (type depends on env): Action to take. + + """ next_state, reward, is_done, info = self.env.step(action) if np.isscalar(is_done): @@ -95,9 +114,23 @@ def get_env_info(self): class RayVecEnv(IVecEnv): + """Main env class that manages several `rl_games.common.vecenv.Rayworker` objects for parallel training + + The RayVecEnv class manages a set of individual environments and wraps around the methods from RayWorker. + Each worker is executed asynchronously. + + """ import ray def __init__(self, config_name, num_actors, **kwargs): + """Initialise the class. Sets up the config for the environment and creates individual workers to manage. + + Args: + config_name (:obj:`str`): Key of the environment to create. + num_actors (:obj:`int`): Number of environments (actors) to create + **kwargs: Misc. 
+
+        """
         builder = self._builders.get(name)
         if not builder:
             raise ValueError(name)
diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py
index 01016723..c29fd4be 100644
--- a/rl_games/common/vecenv.py
+++ b/rl_games/common/vecenv.py
@@ -8,7 +8,20 @@ import torch
 
 class RayWorker:
+    """Wrapper around a third-party (gym for example) environment class that enables parallel training.
+
+    The RayWorker class wraps around another environment class to enable the use of this
+    environment within an asynchronous parallel training setup.
+
+    """
     def __init__(self, config_name, config):
+        """Initialise the class. Sets up the environment creator using the `rl_games.common.env_configurations.configurations` dict.
+
+        Args:
+            config_name (:obj:`str`): Key of the environment to create.
+            config: Misc. kwargs passed on to the environment creator function.
+
+        """
         self.env = configurations[config_name]['env_creator'](**config)
 
     def _obs_to_fp32(self, obs):
@@ -27,6 +40,12 @@ def _obs_to_fp32(self, obs):
         return obs
 
     def step(self, action):
+        """Step the environment and reset it if done.
+
+        Args:
+            action (type depends on env): Action to take.
+
+        """
         next_state, reward, is_done, info = self.env.step(action)
 
         if np.isscalar(is_done):
@@ -95,9 +114,23 @@ def get_env_info(self):
 
 
 class RayVecEnv(IVecEnv):
+    """Main env class that manages several `rl_games.common.vecenv.RayWorker` objects for parallel training.
+
+    The RayVecEnv class manages a set of individual environments and wraps around the methods from RayWorker.
+    Each worker is executed asynchronously.
+
+    """
     import ray
 
     def __init__(self, config_name, num_actors, **kwargs):
+        """Initialise the class. Sets up the config for the environment and creates individual workers to manage.
+
+        Args:
+            config_name (:obj:`str`): Key of the environment to create.
+            num_actors (:obj:`int`): Number of environments (actors) to create.
+            **kwargs: Misc. kwargs passed on to the environment creator function within the RayWorker `__init__`.
+
+        """
         self.config_name = config_name
         self.num_actors = num_actors
         self.use_torch = False
@@ -131,6 +164,14 @@ def __init__(self, config_name, num_actors, **kwargs):
         self.concat_func = np.concatenate
 
     def step(self, actions):
+        """Step all individual environments (using the created workers).
+        Returns a concatenated array of observations, rewards, done states, and infos if the env allows concatenation.
+        Otherwise returns a nested dict.
+
+        Args:
+            actions (type depends on env): Actions to take.
+
+        """
         newobs, newstates, newrewards, newdones, newinfos = [], [], [], [], []
         res_obs = []
         if self.num_agents == 1:
@@ -218,6 +259,12 @@ def reset(self):
 
 vecenv_config = {}
 
 def register(config_name, func):
+    """Add an environment type (for example RayVecEnv) to the list of available types `rl_games.common.vecenv.vecenv_config`.
+    Args:
+        config_name (:obj:`str`): Key of the environment type to register.
+        func (:obj:`func`): Function that creates the environment.
+
+    """
     vecenv_config[config_name] = func
 
 def create_vec_env(config_name, num_actors, **kwargs):
diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py
index 83bb8c07..4377d29b 100644
--- a/rl_games/torch_runner.py
+++ b/rl_games/torch_runner.py
@@ -31,8 +31,24 @@ def _override_sigma(agent, args):
 
 class Runner:
+    """Runs training/inference (playing) procedures as per a given configuration.
+
+    The Runner class provides a high-level API for instantiating agents for either training or playing
+    with an RL algorithm. It further logs training metrics.
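+
+    Example (a minimal sketch; the YAML file name is hypothetical and the config is assumed to follow the usual rl_games layout):
+        >>> import yaml
+        >>> with open('ppo_cartpole.yaml', 'r') as f:
+        ...     config = yaml.safe_load(f)
+        >>> runner = Runner()
+        >>> runner.load(config)
+        >>> runner.run({'train': True, 'play': False})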
+ + """ print('Started to play') player = self.create_player() _restore(player, args) @@ -129,6 +164,12 @@ def reset(self): pass def run(self, args): + """Run either train/play depending on the args. + + Args: + args (:obj:`dict`): Args passed in as a dict obtained from a yaml file or some other config format. + + """ if args['train']: self.run_train(args) elif args['play']: