From df1a63cf6b18a6b9a2cd4f7ae25306f6cfa189dc Mon Sep 17 00:00:00 2001 From: ViktorM Date: Thu, 21 Sep 2023 13:37:25 -0700 Subject: [PATCH 1/6] Ray is an optional dependency now. --- rl_games/common/vecenv.py | 4 +++- rl_games/configs/mujoco/ant.yaml | 8 ++++---- rl_games/torch_runner.py | 2 -- runner.py | 16 ++++++++++++---- setup.py | 2 -- 5 files changed, 19 insertions(+), 13 deletions(-) diff --git a/rl_games/common/vecenv.py b/rl_games/common/vecenv.py index d6261656..646da555 100644 --- a/rl_games/common/vecenv.py +++ b/rl_games/common/vecenv.py @@ -1,4 +1,3 @@ -import ray from rl_games.common.ivecenv import IVecEnv from rl_games.common.env_configurations import configurations from rl_games.common.tr_helpers import dicts_to_dict_with_arrays @@ -8,6 +7,7 @@ from time import sleep import torch + class RayWorker: def __init__(self, config_name, config): self.env = configurations[config_name]['env_creator'](**config) @@ -101,6 +101,8 @@ def __init__(self, config_name, num_actors, **kwargs): self.num_actors = num_actors self.use_torch = False self.seed = kwargs.pop('seed', None) + + import ray self.remote_worker = ray.remote(RayWorker) self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)] diff --git a/rl_games/configs/mujoco/ant.yaml b/rl_games/configs/mujoco/ant.yaml index 4cda780a..2d7ad5bb 100644 --- a/rl_games/configs/mujoco/ant.yaml +++ b/rl_games/configs/mujoco/ant.yaml @@ -26,7 +26,7 @@ params: name: default config: - name: Ant-v4_ray + name: Ant-v3_ray env_name: openai_gym score_to_win: 20000 normalize_input: True @@ -46,8 +46,8 @@ params: truncate_grads: True e_clip: 0.2 max_epochs: 2000 - num_actors: 64 - horizon_length: 64 + num_actors: 8 #64 + horizon_length: 256 #64 minibatch_size: 2048 mini_epochs: 4 critic_coef: 2 @@ -57,7 +57,7 @@ params: bounds_loss_coef: 0.0 env_config: - name: Ant-v4 + name: Ant-v3 seed: 5 player: diff --git a/rl_games/torch_runner.py b/rl_games/torch_runner.py index e8ec70ee..5f8c2ac3 100644 --- a/rl_games/torch_runner.py +++ b/rl_games/torch_runner.py @@ -4,9 +4,7 @@ import random from copy import deepcopy import torch -#import yaml -#from rl_games import envs from rl_games.common import object_factory from rl_games.common import tr_helpers diff --git a/runner.py b/runner.py index 25f79af4..ed680855 100644 --- a/runner.py +++ b/runner.py @@ -1,6 +1,5 @@ from distutils.util import strtobool import argparse, os, yaml -import ray os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" @@ -42,7 +41,12 @@ from rl_games.torch_runner import Runner - ray.init(object_store_memory=1024*1024*1000) + try: + import ray + except ImportError: + pass + else: + ray.init(object_store_memory=1024*1024*1000) runner = Runner() try: @@ -50,7 +54,6 @@ except yaml.YAMLError as exc: print(exc) - #rank = int(os.getenv("LOCAL_RANK", "0")) global_rank = int(os.getenv("RANK", "0")) if args["track"] and global_rank == 0: import wandb @@ -66,7 +69,12 @@ runner.run(args) - ray.shutdown() + try: + import ray + except ImportError: + pass + else: + ray.shutdown() if args["track"] and global_rank == 0: wandb.finish() diff --git a/setup.py b/setup.py index 99c2ea82..d3c36193 100644 --- a/setup.py +++ b/setup.py @@ -39,8 +39,6 @@ 'gym>=0.17.2', 'torch>=1.7.0', 'numpy>=1.16.0', - # to support Python 3.10 - 'ray>=2.2.0', 'tensorboard>=1.14.0', 'tensorboardX>=1.6', 'setproctitle', From 61e998f9089d6680ae34481ea86e376cce2950ba Mon Sep 17 00:00:00 2001 From: ViktorM Date: Mon, 25 Sep 2023 16:14:05 -0700 Subject: [PATCH 2/6] seq_length clean up --- 
rl_games/algos_torch/a2c_continuous.py | 6 +++--- rl_games/algos_torch/a2c_discrete.py | 7 ++++--- rl_games/algos_torch/central_value.py | 14 +++++++------- rl_games/algos_torch/network_builder.py | 23 +++++++++++++++-------- rl_games/common/a2c_common.py | 10 +++++----- rl_games/common/datasets.py | 10 ++++++---- 6 files changed, 40 insertions(+), 30 deletions(-) diff --git a/rl_games/algos_torch/a2c_continuous.py b/rl_games/algos_torch/a2c_continuous.py index 285b8abf..b731a4ed 100644 --- a/rl_games/algos_torch/a2c_continuous.py +++ b/rl_games/algos_torch/a2c_continuous.py @@ -40,7 +40,7 @@ def __init__(self, base_name, params): 'horizon_length' : self.horizon_length, 'num_actors' : self.num_actors, 'num_actions' : self.actions_num, - 'seq_len' : self.seq_len, + 'seq_length' : self.seq_length, 'normalize_value' : self.normalize_value, 'network' : self.central_value_config['network'], 'config' : self.central_value_config, @@ -52,7 +52,7 @@ def __init__(self, base_name, params): self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) self.use_experimental_cv = self.config.get('use_experimental_cv', True) - self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) + self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_length) if self.normalize_value: self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std @@ -98,7 +98,7 @@ def calc_gradients(self, input_dict): if self.is_rnn: rnn_masks = input_dict['rnn_masks'] batch_dict['rnn_states'] = input_dict['rnn_states'] - batch_dict['seq_length'] = self.seq_len + batch_dict['seq_length'] = self.seq_length if self.zero_rnn_on_done: batch_dict['dones'] = input_dict['dones'] diff --git a/rl_games/algos_torch/a2c_discrete.py b/rl_games/algos_torch/a2c_discrete.py index d386cd8e..fc1bda89 100644 --- a/rl_games/algos_torch/a2c_discrete.py +++ b/rl_games/algos_torch/a2c_discrete.py @@ -43,7 +43,7 @@ def __init__(self, base_name, params): 'horizon_length' : self.horizon_length, 'num_actors' : self.num_actors, 'num_actions' : self.actions_num, - 'seq_len' : self.seq_len, + 'seq_length' : self.seq_length, 'normalize_value' : self.normalize_value, 'network' : self.central_value_config['network'], 'config' : self.central_value_config, @@ -55,7 +55,7 @@ def __init__(self, base_name, params): self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device) self.use_experimental_cv = self.config.get('use_experimental_cv', False) - self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len) + self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_length) if self.normalize_value: self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std @@ -127,11 +127,12 @@ def calc_gradients(self, input_dict): } if self.use_action_masks: batch_dict['action_masks'] = input_dict['action_masks'] + rnn_masks = None if self.is_rnn: rnn_masks = input_dict['rnn_masks'] batch_dict['rnn_states'] = input_dict['rnn_states'] - batch_dict['seq_length'] = self.seq_len + batch_dict['seq_length'] = self.seq_length batch_dict['bptt_len'] = self.bptt_len if self.zero_rnn_on_done: batch_dict['dones'] = 
input_dict['dones'] diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index 292ffd09..ecb95bcc 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -16,7 +16,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng nn.Module.__init__(self) self.ppo_device = ppo_device - self.num_agents, self.horizon_length, self.num_actors, self.seq_len = num_agents, horizon_length, num_actors, seq_len + self.num_agents, self.horizon_length, self.num_actors, self.seq_length = num_agents, horizon_length, num_actors, seq_length self.normalize_value = normalize_value self.num_actions = num_actions self.state_shape = state_shape @@ -78,8 +78,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng self.rnn_states = self.model.get_default_rnn_state() self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states] total_agents = self.num_actors #* self.num_agents - num_seqs = self.horizon_length // self.seq_len - assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0) + num_seqs = self.horizon_length // self.seq_length + assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0) self.mb_rnn_states = [ torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states] self.local_rank = 0 @@ -100,7 +100,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng config['print_stats'] = False config['lr_schedule'] = None - self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, True, self.is_rnn, self.ppo_device, self.seq_len) + self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, True, self.is_rnn, self.ppo_device, self.seq_length) def update_lr(self, lr): if self.multi_gpu: @@ -167,9 +167,9 @@ def _preproc_obs(self, obs_batch): def pre_step_rnn(self, n): if not self.is_rnn: return - if n % self.seq_len == 0: + if n % self.seq_length == 0: for s, mb_s in zip(self.rnn_states, self.mb_rnn_states): - mb_s[n // self.seq_len,:,:,:] = s + mb_s[n // self.seq_length,:,:,:] = s def post_step_rnn(self, all_done_indices, zero_rnn_on_done=True): if not self.is_rnn: @@ -245,7 +245,7 @@ def calc_gradients(self, batch): batch_dict = {'obs' : obs_batch, 'actions' : actions_batch, - 'seq_length' : self.seq_len, + 'seq_length' : self.seq_length, 'dones' : dones_batch} if self.is_rnn: batch_dict['rnn_states'] = batch['rnn_states'] diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 86287f49..26434027 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -3,20 +3,17 @@ import torch import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import math -import numpy as np from rl_games.algos_torch.d2rl import D2RLNet from rl_games.algos_torch.sac_helper import SquashedNormal from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue -from rl_games.algos_torch.layers import symexp, symlog + def _create_initializer(func, **kwargs): return lambda v : func(v, **kwargs) + class NetworkBuilder: def __init__(self, **kwargs): pass @@ -196,6 +193,7 @@ def __init__(self, params, **kwargs): input_shape = kwargs.pop('input_shape') self.value_size = kwargs.pop('value_size', 1) self.num_seqs = num_seqs 
= kwargs.pop('num_seqs', 1) + NetworkBuilder.BaseNetwork.__init__(self) self.load(params) self.actor_cnn = nn.Sequential() @@ -306,9 +304,9 @@ def __init__(self, params, **kwargs): def forward(self, obs_dict): obs = obs_dict['obs'] states = obs_dict.get('rnn_states', None) - seq_length = obs_dict.get('seq_length', 1) dones = obs_dict.get('dones', None) bptt_len = obs_dict.get('bptt_len', 0) + if self.has_cnn: # for obs shape 4 # input expected shape (B, W, H, C) @@ -325,6 +323,8 @@ def forward(self, obs_dict): c_out = c_out.contiguous().view(c_out.size(0), -1) if self.has_rnn: + seq_length = obs_dict.get['seq_length'] + if not self.is_rnn_before_mlp: a_out_in = a_out c_out_in = c_out @@ -398,6 +398,8 @@ def forward(self, obs_dict): out = out.flatten(1) if self.has_rnn: + seq_length = obs_dict.get['seq_length'] + out_in = out if not self.is_rnn_before_mlp: out_in = out @@ -703,13 +705,15 @@ def forward(self, obs_dict): dones = obs_dict.get('dones', None) bptt_len = obs_dict.get('bptt_len', 0) states = obs_dict.get('rnn_states', None) - seq_length = obs_dict.get('seq_length', 1) + out = obs out = self.cnn(out) out = out.flatten(1) out = self.flatten_act(out) if self.has_rnn: + seq_length = obs_dict.get['seq_length'] + out_in = out if not self.is_rnn_before_mlp: out_in = out @@ -769,13 +773,15 @@ def load(self, params): self.is_multi_discrete = 'multi_discrete'in params['space'] self.value_activation = params.get('value_activation', 'None') self.normalization = params.get('normalization', None) + if self.is_continuous: self.space_config = params['space']['continuous'] self.fixed_sigma = self.space_config['fixed_sigma'] elif self.is_discrete: self.space_config = params['space']['discrete'] elif self.is_multi_discrete: - self.space_config = params['space']['multi_discrete'] + self.space_config = params['space']['multi_discrete'] + self.has_rnn = 'rnn' in params if self.has_rnn: self.rnn_units = params['rnn']['units'] @@ -783,6 +789,7 @@ def load(self, params): self.rnn_name = params['rnn']['name'] self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False) self.rnn_ln = params['rnn'].get('layer_norm', False) + self.has_cnn = True self.permute_input = params['cnn'].get('permute_input', True) self.conv_depths = params['cnn']['conv_depths'] diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 646dd809..995d98e9 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -202,8 +202,8 @@ def __init__(self, base_name, params): self.rewards_shaper = config['reward_shaper'] self.num_agents = self.env_info.get('agents', 1) self.horizon_length = config['horizon_length'] - self.seq_len = self.config.get('seq_length', 4) - self.bptt_len = self.config.get('bptt_length', self.seq_len) # not used right now. Didn't show that it is usefull + self.seq_length = self.config.get('seq_length', 4) + self.bptt_len = self.config.get('bptt_length', self.seq_length) # not used right now. 
Didn't show that it is usefull self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True) self.normalize_advantage = config['normalize_advantage'] self.normalize_rms_advantage = config.get('normalize_rms_advantage', False) @@ -229,7 +229,7 @@ def __init__(self, base_name, params): self.game_shaped_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device) self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device) self.obs = None - self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation + self.games_num = self.config['minibatch_size'] // self.seq_length # it is used only for current rnn implementation self.batch_size = self.horizon_length * self.num_actors * self.num_agents self.batch_size_envs = self.horizon_length * self.num_actors @@ -464,7 +464,7 @@ def init_tensors(self): total_agents = self.num_agents * self.num_actors num_seqs = self.horizon_length // self.seq_len - assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0) + assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0) self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states] def init_rnn_from_model(self, model): @@ -792,7 +792,7 @@ def play_steps_rnn(self): step_time = 0.0 for n in range(self.horizon_length): - if n % self.seq_len == 0: + if n % self.seq_length == 0: for s, mb_s in zip(self.rnn_states, mb_rnn_states): mb_s[n // self.seq_len,:,:,:] = s diff --git a/rl_games/common/datasets.py b/rl_games/common/datasets.py index a2b6c14f..56e7335c 100644 --- a/rl_games/common/datasets.py +++ b/rl_games/common/datasets.py @@ -2,20 +2,22 @@ import copy from torch.utils.data import Dataset + class PPODataset(Dataset): - def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len): + + def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length): self.is_rnn = is_rnn - self.seq_len = seq_len + self.seq_length = seq_length self.batch_size = batch_size self.minibatch_size = minibatch_size self.device = device self.length = self.batch_size // self.minibatch_size self.is_discrete = is_discrete self.is_continuous = not is_discrete - total_games = self.batch_size // self.seq_len + total_games = self.batch_size // self.seq_length self.num_games_batch = self.minibatch_size // self.seq_len self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device) - self.flat_indexes = torch.arange(total_games * self.seq_len, dtype=torch.long, device=self.device).reshape(total_games, self.seq_len) + self.flat_indexes = torch.arange(total_games * self.seq_len, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length) self.special_names = ['rnn_states'] From 06756a0fdd73dfa9c6c02225365f503f06c233a1 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Mon, 25 Sep 2023 16:19:53 -0700 Subject: [PATCH 3/6] More seq_length work. 
--- rl_games/algos_torch/central_value.py | 2 +- rl_games/common/a2c_common.py | 5 ++++- rl_games/common/datasets.py | 11 ++++++----- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index ecb95bcc..ac91e237 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -12,7 +12,7 @@ class CentralValueTrain(nn.Module): def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions, - seq_len, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done): + seq_length, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done): nn.Module.__init__(self) self.ppo_device = ppo_device diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 995d98e9..a3c53df7 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -202,7 +202,10 @@ def __init__(self, base_name, params): self.rewards_shaper = config['reward_shaper'] self.num_agents = self.env_info.get('agents', 1) self.horizon_length = config['horizon_length'] + + # seq_length is used only with rnn policy and value functions self.seq_length = self.config.get('seq_length', 4) + print('seq_length:', self.seq_length) self.bptt_len = self.config.get('bptt_length', self.seq_length) # not used right now. Didn't show that it is usefull self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True) self.normalize_advantage = config['normalize_advantage'] @@ -794,7 +797,7 @@ def play_steps_rnn(self): for n in range(self.horizon_length): if n % self.seq_length == 0: for s, mb_s in zip(self.rnn_states, mb_rnn_states): - mb_s[n // self.seq_len,:,:,:] = s + mb_s[n // self.seq_length,:,:,:] = s if self.has_central_value: self.central_value_net.pre_step_rnn(n) diff --git a/rl_games/common/datasets.py b/rl_games/common/datasets.py index 56e7335c..13f76e05 100644 --- a/rl_games/common/datasets.py +++ b/rl_games/common/datasets.py @@ -15,9 +15,9 @@ def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_ self.is_discrete = is_discrete self.is_continuous = not is_discrete total_games = self.batch_size // self.seq_length - self.num_games_batch = self.minibatch_size // self.seq_len + self.num_games_batch = self.minibatch_size // self.seq_length self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device) - self.flat_indexes = torch.arange(total_games * self.seq_len, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length) + self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length) self.special_names = ['rnn_states'] @@ -36,9 +36,10 @@ def __len__(self): def _get_item_rnn(self, idx): gstart = idx * self.num_games_batch gend = (idx + 1) * self.num_games_batch - start = gstart * self.seq_len - end = gend * self.seq_len - self.last_range = (start, end) + start = gstart * self.seq_length + end = gend * self.seq_length + self.last_range = (start, end) + input_dict = {} for k,v in self.values_dict.items(): if k not in self.special_names: From 872e767a50e82629bb140711d27c111688112504 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Mon, 25 Sep 2023 16:27:25 -0700 Subject: [PATCH 4/6] Fixes. 
--- README.md | 1 + rl_games/algos_torch/network_builder.py | 6 +++--- rl_games/common/a2c_common.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6410f09d..7393c8ff 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,7 @@ Additional environment supported properties and functions * Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled. * Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code. * Fixed SAC not loading weights properly. +* Removed Ray dependency for use cases it's not required. 1.6.0 diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 26434027..73447607 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -323,7 +323,7 @@ def forward(self, obs_dict): c_out = c_out.contiguous().view(c_out.size(0), -1) if self.has_rnn: - seq_length = obs_dict.get['seq_length'] + seq_length = obs_dict['seq_length'] if not self.is_rnn_before_mlp: a_out_in = a_out @@ -398,7 +398,7 @@ def forward(self, obs_dict): out = out.flatten(1) if self.has_rnn: - seq_length = obs_dict.get['seq_length'] + seq_length = obs_dict['seq_length'] out_in = out if not self.is_rnn_before_mlp: @@ -712,7 +712,7 @@ def forward(self, obs_dict): out = self.flatten_act(out) if self.has_rnn: - seq_length = obs_dict.get['seq_length'] + seq_length = obs_dict['seq_length'] out_in = out if not self.is_rnn_before_mlp: diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index a3c53df7..b1cda019 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -466,7 +466,7 @@ def init_tensors(self): self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states] total_agents = self.num_agents * self.num_actors - num_seqs = self.horizon_length // self.seq_len + num_seqs = self.horizon_length // self.seq_length assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0) self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states] From 94e0ce1ef216c5a30fb6ce4873be0708a797e9bc Mon Sep 17 00:00:00 2001 From: ViktorM Date: Mon, 25 Sep 2023 21:39:26 -0700 Subject: [PATCH 5/6] Temporary reverted back seq_length network builder change. 
--- rl_games/algos_torch/central_value.py | 3 +-- rl_games/algos_torch/network_builder.py | 9 ++++++--- rl_games/common/a2c_common.py | 4 ++++ rl_games/common/datasets.py | 1 + 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/rl_games/algos_torch/central_value.py b/rl_games/algos_torch/central_value.py index ac91e237..d75c687c 100644 --- a/rl_games/algos_torch/central_value.py +++ b/rl_games/algos_torch/central_value.py @@ -183,7 +183,6 @@ def post_step_rnn(self, all_done_indices, zero_rnn_on_done=True): def forward(self, input_dict): return self.model(input_dict) - def get_value(self, input_dict): self.eval() obs_batch = input_dict['states'] @@ -284,5 +283,5 @@ def calc_gradients(self, batch): nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm) self.optimizer.step() - + return loss diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index 73447607..ce163e48 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -323,7 +323,8 @@ def forward(self, obs_dict): c_out = c_out.contiguous().view(c_out.size(0), -1) if self.has_rnn: - seq_length = obs_dict['seq_length'] + #seq_length = obs_dict['seq_length'] + seq_length = obs_dict.get('seq_length', 1) if not self.is_rnn_before_mlp: a_out_in = a_out @@ -398,7 +399,8 @@ def forward(self, obs_dict): out = out.flatten(1) if self.has_rnn: - seq_length = obs_dict['seq_length'] + #seq_length = obs_dict['seq_length'] + seq_length = obs_dict.get('seq_length', 1) out_in = out if not self.is_rnn_before_mlp: @@ -712,7 +714,8 @@ def forward(self, obs_dict): out = self.flatten_act(out) if self.has_rnn: - seq_length = obs_dict['seq_length'] + #seq_length = obs_dict['seq_length'] + seq_length = obs_dict.get('seq_length', 1) out_in = out if not self.is_rnn_before_mlp: diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index b1cda019..9c4f1981 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -807,6 +807,7 @@ def play_steps_rnn(self): res_dict = self.get_masked_action_values(self.obs, masks) else: res_dict = self.get_action_values(self.obs) + self.rnn_states = res_dict['rnn_states'] self.experience_buffer.update_data('obses', n, self.obs['obs']) self.experience_buffer.update_data('dones', n, self.dones.byte()) @@ -863,6 +864,7 @@ def play_steps_rnn(self): mb_advs = self.discount_values(fdones, last_values, mb_fdones, mb_values, mb_rewards) mb_returns = mb_advs + mb_values batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list) + batch_dict['returns'] = swap_and_flatten01(mb_returns) batch_dict['played_frames'] = self.batch_size states = [] @@ -870,8 +872,10 @@ def play_steps_rnn(self): t_size = mb_s.size()[0] * mb_s.size()[2] h_size = mb_s.size()[3] states.append(mb_s.permute(1,2,0,3).reshape(-1,t_size, h_size)) + batch_dict['rnn_states'] = states batch_dict['step_time'] = step_time + return batch_dict diff --git a/rl_games/common/datasets.py b/rl_games/common/datasets.py index 13f76e05..3a48f3cf 100644 --- a/rl_games/common/datasets.py +++ b/rl_games/common/datasets.py @@ -6,6 +6,7 @@ class PPODataset(Dataset): def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length): + self.is_rnn = is_rnn self.seq_length = seq_length self.batch_size = batch_size From c645d9a403a3b838c8e7cc217e8eec3274bdf740 Mon Sep 17 00:00:00 2001 From: ViktorM Date: Tue, 26 Sep 2023 07:52:02 -0700 Subject: [PATCH 6/6] Readme update. 
--- README.md | 2 ++ pyproject.toml | 1 - rl_games/algos_torch/network_builder.py | 4 ++-- rl_games/common/a2c_common.py | 4 ++++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7393c8ff..f621ae19 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,8 @@ Additional environment supported properties and functions * Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code. * Fixed SAC not loading weights properly. * Removed Ray dependency for use cases it's not required. +* Added warning for using deprecated 'seq_len' instead of 'seq_length' in configs with RNN networks. + 1.6.0 diff --git a/pyproject.toml b/pyproject.toml index bed33c78..e73c4c42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ tensorboardX = "^2.5" PyYAML = "^6.0" psutil = "^5.9.0" setproctitle = "^1.2.2" -ray = "^1.11.0" opencv-python = "^4.5.5" wandb = "^0.12.11" diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py index ce163e48..ce5651c5 100644 --- a/rl_games/algos_torch/network_builder.py +++ b/rl_games/algos_torch/network_builder.py @@ -323,7 +323,6 @@ def forward(self, obs_dict): c_out = c_out.contiguous().view(c_out.size(0), -1) if self.has_rnn: - #seq_length = obs_dict['seq_length'] seq_length = obs_dict.get('seq_length', 1) if not self.is_rnn_before_mlp: @@ -360,9 +359,11 @@ def forward(self, obs_dict): c_out = c_out.transpose(0,1) a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1) c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1) + if self.rnn_ln: a_out = self.a_layer_norm(a_out) c_out = self.c_layer_norm(c_out) + if type(a_states) is not tuple: a_states = (a_states,) c_states = (c_states,) @@ -399,7 +400,6 @@ def forward(self, obs_dict): out = out.flatten(1) if self.has_rnn: - #seq_length = obs_dict['seq_length'] seq_length = obs_dict.get('seq_length', 1) out_in = out diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 9c4f1981..63b90c07 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -204,10 +204,14 @@ def __init__(self, base_name, params): self.horizon_length = config['horizon_length'] # seq_length is used only with rnn policy and value functions + if 'seq_len' in config: + print('WARNING: seq_len is deprecated, use seq_length instead') + self.seq_length = self.config.get('seq_length', 4) print('seq_length:', self.seq_length) self.bptt_len = self.config.get('bptt_length', self.seq_length) # not used right now. Didn't show that it is usefull self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True) + self.normalize_advantage = config['normalize_advantage'] self.normalize_rms_advantage = config.get('normalize_rms_advantage', False) self.normalize_input = self.config['normalize_input']
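
Note on the Ray change (PATCH 1/6): vecenv.py now defers "import ray" until a RayVecEnv is actually constructed, and runner.py wraps ray.init()/ray.shutdown() in try/except ImportError guards, so single-process training no longer hard-requires the package. A minimal sketch of that guard pattern, assuming standalone helpers; the names maybe_init_ray/maybe_shutdown_ray are illustrative only and do not appear in the patch:

    # Import Ray if it is installed; otherwise fall back to None so the
    # rest of the script can run without it.
    try:
        import ray
    except ImportError:
        ray = None

    def maybe_init_ray(object_store_memory=1024 * 1024 * 1000):
        """Initialize Ray only when the package is available."""
        if ray is not None:
            ray.init(object_store_memory=object_store_memory)

    def maybe_shutdown_ray():
        """Shut Ray down only if it was importable in the first place."""
        if ray is not None:
            ray.shutdown()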
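
Note on the seq_length change (PATCHES 2/6-6/6): the config key is settled on 'seq_length' everywhere, and a2c_common.py now prints a deprecation warning if the old 'seq_len' key is present; the old key's value is not used, only the warning fires. A small sketch of that behavior, assuming a plain dict config rather than the trainer's self.config:

    def resolve_seq_length(config, default=4):
        """Warn about the deprecated 'seq_len' key and return 'seq_length'."""
        if 'seq_len' in config:
            print('WARNING: seq_len is deprecated, use seq_length instead')
        # seq_length only matters for RNN policy/value networks.
        return config.get('seq_length', default)

    print(resolve_seq_length({'seq_len': 8}))      # warns, returns the default 4
    print(resolve_seq_length({'seq_length': 16}))  # returns 16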