Skip to content

Commit

Permalink
Fixed torch_runner. Updated Brax colab. (#205)
Browse files Browse the repository at this point in the history
* Fixed torch_runner. Updated Brax colab.

* Fix.

* Fix.

* Final clean-up.
  • Loading branch information
ViktorM authored Sep 27, 2022
1 parent 0ca4544 commit f3e9c7f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 511 deletions.
492 changes: 24 additions & 468 deletions notebooks/brax_training.ipynb

Large diffs are not rendered by default.

33 changes: 8 additions & 25 deletions notebooks/mujoco_envpool_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,27 @@
"metadata": {},
"outputs": [],
"source": [
"!nvidia-smi -L"
"!pip show rl-games"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6qvHCGgpxrvZ"
},
"metadata": {},
"outputs": [],
"source": [
"%load_ext tensorboard"
"!nvidia-smi -L"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GFv1FDtJyC0z",
"outputId": "4082ccf2-139d-415a-c832-8b39f622e899"
"id": "6qvHCGgpxrvZ"
},
"outputs": [],
"source": [
"!pip show rl-games"
"%load_ext tensorboard"
]
},
{
Expand Down Expand Up @@ -367,17 +361,6 @@
"%tensorboard --logdir 'runs/'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fyvlWdM_abGR"
},
"outputs": [],
"source": [
"from rl_games.torch_runner import Runner"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -500,9 +483,10 @@
"outputs": [],
"source": [
"import yaml\n",
"from rl_games.torch_runner import Runner\n",
"\n",
"config = walker_config\n",
"config['params']['config']['full_experiment_name'] = 'mujoco'\n",
"config['params']['config']['full_experiment_name'] = 'Walker2d_mujoco'\n",
"config['params']['config']['max_epochs'] = 500\n",
"config['params']['config']['horizon_length'] = 512\n",
"config['params']['config']['num_actors'] = 8\n",
Expand Down Expand Up @@ -531,8 +515,7 @@
"config = player_walker_config\n",
"config['params']['config']['player']['render'] = False\n",
"config['params']['config']['player']['games_num'] = 2\n",
" \n",
"runner = Runner()\n",
"\n",
"runner.load(config)\n",
"agent = runner.create_player()\n",
"agent.restore('runs/mujoco/nn/Walker2d-v4.pth')"
Expand Down
25 changes: 13 additions & 12 deletions rl_games/torch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,20 @@
import time
import numpy as np
import random
import copy
from copy import deepcopy
import torch
import yaml
#import yaml

from rl_games import envs
#from rl_games import envs
from rl_games.common import object_factory
from rl_games.common import env_configurations
from rl_games.common import experiment
from rl_games.common import tr_helpers

from rl_games.algos_torch import model_builder
from rl_games.algos_torch import a2c_continuous
from rl_games.algos_torch import a2c_discrete
from rl_games.algos_torch import players
from rl_games.common.algo_observer import DefaultAlgoObserver
from rl_games.algos_torch import sac_agent
import rl_games.networks


def _restore(agent, args):
if 'checkpoint' in args and args['checkpoint'] is not None and args['checkpoint'] !='':
Expand All @@ -33,6 +30,8 @@ def _override_sigma(agent, args):
net.sigma.fill_(float(args['sigma']))
else:
print('Print cannot set new sigma because fixed_sigma is False')


class Runner:
def __init__(self, algo_observer=None):
self.algo_factory = object_factory.ObjectFactory()
Expand All @@ -52,28 +51,29 @@ def __init__(self, algo_observer=None):
### it didnot help for lots for openai gym envs anyway :(
#torch.backends.cudnn.deterministic = True
#torch.use_deterministic_algorithms(True)

def reset(self):
pass

def load_config(self, params):
self.seed = params.get('seed', None)
if self.seed is None:
self.seed = int(time.time())

if params["config"].get('multi_gpu', False):
self.seed += int(os.getenv("LOCAL_RANK", "0"))
print(f"self.seed = {self.seed}")

self.algo_params = params['algo']
self.algo_name = self.algo_params['name']
self.exp_config = None
if self.seed:

if self.seed:
torch.manual_seed(self.seed)
torch.cuda.manual_seed_all(self.seed)
np.random.seed(self.seed)
random.seed(self.seed)

# deal with environment specific seed if applicable
if 'env_config' in params['config']:
if not 'seed' in params['config']['env_config']:
Expand All @@ -89,8 +89,9 @@ def load_config(self, params):
config['features']['observer'] = self.algo_observer
self.params = params

def load(self, yaml_conf):
self.default_config = yaml_conf['params']
def load(self, yaml_config):
config = deepcopy(yaml_config)
self.default_config = deepcopy(config['params'])
self.load_config(params=self.default_config)

def run_train(self, args):
Expand Down
12 changes: 6 additions & 6 deletions runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from distutils.util import strtobool
import numpy as np
import argparse, copy, os, yaml
import ray, signal
import ray

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
#import warnings
#warnings.filterwarnings("error")

if __name__ == '__main__':
ap = argparse.ArgumentParser()
ap.add_argument("--seed", type=int, default=0, required=False,
Expand All @@ -15,7 +16,7 @@
ap.add_argument("-p", "--play", required=False, help="play(test) network", action='store_true')
ap.add_argument("-c", "--checkpoint", required=False, help="path to checkpoint")
ap.add_argument("-f", "--file", required=True, help="path to config")
ap.add_argument("-na", "--num_actors", type=int, default=0, required=False,
ap.add_argument("-na", "--num_actors", type=int, default=0, required=False,
help="number of envs running in parallel, if larger than 0 will overwrite the value in yaml config")
ap.add_argument("-s", "--sigma", type=float, required=False, help="sets new sigma value in case if 'fixed_sigma: True' in yaml config")
ap.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
Expand All @@ -29,22 +30,21 @@

args = vars(ap.parse_args())
config_name = args['file']

print('Loading config: ', config_name)
with open(config_name, 'r') as stream:
config = yaml.safe_load(stream)

if args['num_actors'] > 0:
config['params']['config']['num_actors'] = args['num_actors']

if args['seed'] > 0:
config['params']['seed'] = args['seed']
config['params']['config']['env_config']['seed'] = args['seed']

from rl_games.torch_runner import Runner

ray.init(object_store_memory=1024*1024*1000)
#signal.signal(signal.SIGINT, exit_gracefully)

runner = Runner()
try:
Expand All @@ -68,6 +68,6 @@
runner.run(args)

ray.shutdown()

if args["track"] and rank == 0:
wandb.finish()

0 comments on commit f3e9c7f

Please sign in to comment.