Ray dependency and seq_length improvements #253

Merged (7 commits, Sep 26, 2023)
Changes from all commits
3 changes: 3 additions & 0 deletions README.md
@@ -305,6 +305,9 @@ Additional environment supported properties and functions
* Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
* Fixed SAC not loading weights properly.
* Removed Ray dependency for use cases where it is not required.
* Added warning for using deprecated 'seq_len' instead of 'seq_length' in configs with RNN networks.


1.6.0

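To make the seq_length changelog entry concrete, here is a minimal sketch of the config migration, in line with the handling added in rl_games/common/a2c_common.py below. The surrounding keys and values are illustrative assumptions; only the 'seq_len'/'seq_length' keys come from this PR.

# Old-style config: triggers "WARNING: seq_len is deprecated, use seq_length instead"
# and, because only 'seq_length' is read, silently falls back to the default of 4.
old_rnn_config = {
    'horizon_length': 128,   # illustrative value
    'seq_len': 16,           # deprecated key
}

# Migrated config: the value is actually picked up.
new_rnn_config = {
    'horizon_length': 128,   # illustrative value
    'seq_length': 16,        # preferred key; defaults to 4 when omitted
}
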
1 change: 0 additions & 1 deletion pyproject.toml
@@ -16,7 +16,6 @@ tensorboardX = "^2.5"
PyYAML = "^6.0"
psutil = "^5.9.0"
setproctitle = "^1.2.2"
ray = "^1.11.0"
opencv-python = "^4.5.5"
wandb = "^0.12.11"

6 changes: 3 additions & 3 deletions rl_games/algos_torch/a2c_continuous.py
@@ -40,7 +40,7 @@ def __init__(self, base_name, params):
'horizon_length' : self.horizon_length,
'num_actors' : self.num_actors,
'num_actions' : self.actions_num,
'seq_len' : self.seq_len,
'seq_length' : self.seq_length,
'normalize_value' : self.normalize_value,
'network' : self.central_value_config['network'],
'config' : self.central_value_config,
@@ -52,7 +52,7 @@ def __init__(self, base_name, params):
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

self.use_experimental_cv = self.config.get('use_experimental_cv', True)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_length)
if self.normalize_value:
self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std

@@ -98,7 +98,7 @@ def calc_gradients(self, input_dict):
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
batch_dict['rnn_states'] = input_dict['rnn_states']
batch_dict['seq_length'] = self.seq_len
batch_dict['seq_length'] = self.seq_length

if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']
7 changes: 4 additions & 3 deletions rl_games/algos_torch/a2c_discrete.py
@@ -43,7 +43,7 @@ def __init__(self, base_name, params):
'horizon_length' : self.horizon_length,
'num_actors' : self.num_actors,
'num_actions' : self.actions_num,
'seq_len' : self.seq_len,
'seq_length' : self.seq_length,
'normalize_value' : self.normalize_value,
'network' : self.central_value_config['network'],
'config' : self.central_value_config,
@@ -55,7 +55,7 @@ def __init__(self, base_name, params):
self.central_value_net = central_value.CentralValueTrain(**cv_config).to(self.ppo_device)

self.use_experimental_cv = self.config.get('use_experimental_cv', False)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_len)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, self.is_discrete, self.is_rnn, self.ppo_device, self.seq_length)

if self.normalize_value:
self.value_mean_std = self.central_value_net.model.value_mean_std if self.has_central_value else self.model.value_mean_std
@@ -127,11 +127,12 @@ def calc_gradients(self, input_dict):
}
if self.use_action_masks:
batch_dict['action_masks'] = input_dict['action_masks']

rnn_masks = None
if self.is_rnn:
rnn_masks = input_dict['rnn_masks']
batch_dict['rnn_states'] = input_dict['rnn_states']
batch_dict['seq_length'] = self.seq_len
batch_dict['seq_length'] = self.seq_length
batch_dict['bptt_len'] = self.bptt_len
if self.zero_rnn_on_done:
batch_dict['dones'] = input_dict['dones']
19 changes: 9 additions & 10 deletions rl_games/algos_torch/central_value.py
@@ -12,11 +12,11 @@
class CentralValueTrain(nn.Module):

def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions,
seq_len, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
seq_length, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
nn.Module.__init__(self)

self.ppo_device = ppo_device
self.num_agents, self.horizon_length, self.num_actors, self.seq_len = num_agents, horizon_length, num_actors, seq_len
self.num_agents, self.horizon_length, self.num_actors, self.seq_length = num_agents, horizon_length, num_actors, seq_length
self.normalize_value = normalize_value
self.num_actions = num_actions
self.state_shape = state_shape
@@ -78,8 +78,8 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
self.rnn_states = self.model.get_default_rnn_state()
self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
total_agents = self.num_actors #* self.num_agents
num_seqs = self.horizon_length // self.seq_len
assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0)
num_seqs = self.horizon_length // self.seq_length
assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
self.mb_rnn_states = [ torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states]

self.local_rank = 0
@@ -100,7 +100,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
config['print_stats'] = False
config['lr_schedule'] = None

self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, True, self.is_rnn, self.ppo_device, self.seq_len)
self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, True, self.is_rnn, self.ppo_device, self.seq_length)

def update_lr(self, lr):
if self.multi_gpu:
@@ -167,9 +167,9 @@ def _preproc_obs(self, obs_batch):
def pre_step_rnn(self, n):
if not self.is_rnn:
return
if n % self.seq_len == 0:
if n % self.seq_length == 0:
for s, mb_s in zip(self.rnn_states, self.mb_rnn_states):
mb_s[n // self.seq_len,:,:,:] = s
mb_s[n // self.seq_length,:,:,:] = s

def post_step_rnn(self, all_done_indices, zero_rnn_on_done=True):
if not self.is_rnn:
@@ -183,7 +183,6 @@ def post_step_rnn(self, all_done_indices, zero_rnn_on_done=True):
def forward(self, input_dict):
return self.model(input_dict)


def get_value(self, input_dict):
self.eval()
obs_batch = input_dict['states']
@@ -245,7 +244,7 @@ def calc_gradients(self, batch):

batch_dict = {'obs' : obs_batch,
'actions' : actions_batch,
'seq_length' : self.seq_len,
'seq_length' : self.seq_length,
'dones' : dones_batch}
if self.is_rnn:
batch_dict['rnn_states'] = batch['rnn_states']
@@ -284,5 +283,5 @@ def calc_gradients(self, batch):
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

self.optimizer.step()

return loss
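As a sanity check on the asserts above (now expressed in terms of seq_length), here is a worked example with illustrative numbers; none of these values come from a shipped config.

horizon_length = 64
seq_length = 8
total_agents = 16        # num_actors for the central value network
num_minibatches = 4

num_seqs = horizon_length // seq_length                                         # 8 stored rnn-state snapshots
transitions_per_minibatch = horizon_length * total_agents // num_minibatches    # 1024 // 4 = 256
assert transitions_per_minibatch % seq_length == 0                              # 256 % 8 == 0, so the constraint holds
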
26 changes: 18 additions & 8 deletions rl_games/algos_torch/network_builder.py
@@ -3,20 +3,17 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import math
import numpy as np
from rl_games.algos_torch.d2rl import D2RLNet
from rl_games.algos_torch.sac_helper import SquashedNormal
from rl_games.common.layers.recurrent import GRUWithDones, LSTMWithDones
from rl_games.common.layers.value import TwoHotEncodedValue, DefaultValue
from rl_games.algos_torch.layers import symexp, symlog


def _create_initializer(func, **kwargs):
return lambda v : func(v, **kwargs)


class NetworkBuilder:
def __init__(self, **kwargs):
pass
@@ -196,6 +193,7 @@ def __init__(self, params, **kwargs):
input_shape = kwargs.pop('input_shape')
self.value_size = kwargs.pop('value_size', 1)
self.num_seqs = num_seqs = kwargs.pop('num_seqs', 1)

NetworkBuilder.BaseNetwork.__init__(self)
self.load(params)
self.actor_cnn = nn.Sequential()
@@ -306,9 +304,9 @@ def __init__(self, params, **kwargs):
def forward(self, obs_dict):
obs = obs_dict['obs']
states = obs_dict.get('rnn_states', None)
seq_length = obs_dict.get('seq_length', 1)
dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)

if self.has_cnn:
# for obs shape 4
# input expected shape (B, W, H, C)
@@ -325,6 +323,8 @@ def forward(self, obs_dict):
c_out = c_out.contiguous().view(c_out.size(0), -1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)

if not self.is_rnn_before_mlp:
a_out_in = a_out
c_out_in = c_out
@@ -359,9 +359,11 @@ def forward(self, obs_dict):
c_out = c_out.transpose(0,1)
a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1)
c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1)

if self.rnn_ln:
a_out = self.a_layer_norm(a_out)
c_out = self.c_layer_norm(c_out)

if type(a_states) is not tuple:
a_states = (a_states,)
c_states = (c_states,)
@@ -398,6 +400,8 @@ def forward(self, obs_dict):
out = out.flatten(1)

if self.has_rnn:
seq_length = obs_dict.get('seq_length', 1)

out_in = out
if not self.is_rnn_before_mlp:
out_in = out
@@ -703,13 +707,16 @@ def forward(self, obs_dict):
dones = obs_dict.get('dones', None)
bptt_len = obs_dict.get('bptt_len', 0)
states = obs_dict.get('rnn_states', None)
seq_length = obs_dict.get('seq_length', 1)

out = obs
out = self.cnn(out)
out = out.flatten(1)
out = self.flatten_act(out)

if self.has_rnn:
#seq_length = obs_dict['seq_length']
seq_length = obs_dict.get('seq_length', 1)

out_in = out
if not self.is_rnn_before_mlp:
out_in = out
@@ -769,20 +776,23 @@ def load(self, params):
self.is_multi_discrete = 'multi_discrete'in params['space']
self.value_activation = params.get('value_activation', 'None')
self.normalization = params.get('normalization', None)

if self.is_continuous:
self.space_config = params['space']['continuous']
self.fixed_sigma = self.space_config['fixed_sigma']
elif self.is_discrete:
self.space_config = params['space']['discrete']
elif self.is_multi_discrete:
self.space_config = params['space']['multi_discrete']
self.space_config = params['space']['multi_discrete']

self.has_rnn = 'rnn' in params
if self.has_rnn:
self.rnn_units = params['rnn']['units']
self.rnn_layers = params['rnn']['layers']
self.rnn_name = params['rnn']['name']
self.is_rnn_before_mlp = params['rnn'].get('before_mlp', False)
self.rnn_ln = params['rnn'].get('layer_norm', False)

self.has_cnn = True
self.permute_input = params['cnn'].get('permute_input', True)
self.conv_depths = params['cnn']['conv_depths']
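For context on why seq_length moved inside the has_rnn branches, here is a sketch of the obs_dict the builders' forward() consumes. The tensor shapes and the hidden size are assumptions for illustration, and the commented-out call stands in for a concrete network instance.

import torch

games, seq_length, obs_dim, hidden = 32, 8, 12, 256
obs_dict = {
    'obs': torch.zeros(games * seq_length, obs_dim),   # flattened (sequences * steps, obs_dim)
    'rnn_states': (torch.zeros(1, games, hidden),),    # e.g. one GRU layer's hidden state
    'seq_length': seq_length,   # now read only on the RNN path; defaults to 1 when absent
    'dones': torch.zeros(games * seq_length),          # consulted when zero_rnn_on_done is enabled
    'bptt_len': 0,              # optional truncated-BPTT length; defaults to 0
}
# outputs = network(obs_dict)  # head-specific tuple; non-recurrent networks ignore 'seq_length'
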
25 changes: 18 additions & 7 deletions rl_games/common/a2c_common.py
@@ -202,9 +202,16 @@ def __init__(self, base_name, params):
self.rewards_shaper = config['reward_shaper']
self.num_agents = self.env_info.get('agents', 1)
self.horizon_length = config['horizon_length']
self.seq_len = self.config.get('seq_length', 4)
self.bptt_len = self.config.get('bptt_length', self.seq_len) # not used right now. Didn't show that it is useful

# seq_length is used only with rnn policy and value functions
if 'seq_len' in config:
print('WARNING: seq_len is deprecated, use seq_length instead')

self.seq_length = self.config.get('seq_length', 4)
print('seq_length:', self.seq_length)
self.bptt_len = self.config.get('bptt_length', self.seq_length) # not used right now. Didn't show that it is useful
self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True)

self.normalize_advantage = config['normalize_advantage']
self.normalize_rms_advantage = config.get('normalize_rms_advantage', False)
self.normalize_input = self.config['normalize_input']
@@ -229,7 +236,7 @@ def __init__(self, base_name, params):
self.game_shaped_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
self.obs = None
self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
self.games_num = self.config['minibatch_size'] // self.seq_length # it is used only for current rnn implementation

self.batch_size = self.horizon_length * self.num_actors * self.num_agents
self.batch_size_envs = self.horizon_length * self.num_actors
@@ -463,8 +470,8 @@ def init_tensors(self):
self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]

total_agents = self.num_agents * self.num_actors
num_seqs = self.horizon_length // self.seq_len
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0)
num_seqs = self.horizon_length // self.seq_length
assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

def init_rnn_from_model(self, model):
Expand Down Expand Up @@ -792,9 +799,9 @@ def play_steps_rnn(self):
step_time = 0.0

for n in range(self.horizon_length):
if n % self.seq_len == 0:
if n % self.seq_length == 0:
for s, mb_s in zip(self.rnn_states, mb_rnn_states):
mb_s[n // self.seq_len,:,:,:] = s
mb_s[n // self.seq_length,:,:,:] = s

if self.has_central_value:
self.central_value_net.pre_step_rnn(n)
@@ -804,6 +811,7 @@
res_dict = self.get_masked_action_values(self.obs, masks)
else:
res_dict = self.get_action_values(self.obs)

self.rnn_states = res_dict['rnn_states']
self.experience_buffer.update_data('obses', n, self.obs['obs'])
self.experience_buffer.update_data('dones', n, self.dones.byte())
@@ -860,15 +868,18 @@ def play_steps_rnn(self):
mb_advs = self.discount_values(fdones, last_values, mb_fdones, mb_values, mb_rewards)
mb_returns = mb_advs + mb_values
batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list)

batch_dict['returns'] = swap_and_flatten01(mb_returns)
batch_dict['played_frames'] = self.batch_size
states = []
for mb_s in mb_rnn_states:
t_size = mb_s.size()[0] * mb_s.size()[2]
h_size = mb_s.size()[3]
states.append(mb_s.permute(1,2,0,3).reshape(-1,t_size, h_size))

batch_dict['rnn_states'] = states
batch_dict['step_time'] = step_time

return batch_dict


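The rename above is mechanical, but the rnn-state packing at the end of play_steps_rnn() is easy to misread, so here is a shape walk-through with illustrative sizes (two recurrent layers, 16 agents, horizon 64, seq_length 8); these numbers are assumptions, not defaults.

import torch

num_layers, total_agents, hidden = 2, 16, 128
horizon_length, seq_length = 64, 8
num_seqs = horizon_length // seq_length              # states are snapshotted every seq_length steps

mb_s = torch.zeros(num_seqs, num_layers, total_agents, hidden)   # one entry of mb_rnn_states
t_size = mb_s.size(0) * mb_s.size(2)                 # num_seqs * total_agents = 128
h_size = mb_s.size(3)
packed = mb_s.permute(1, 2, 0, 3).reshape(-1, t_size, h_size)
print(packed.shape)                                  # torch.Size([2, 128, 128]): (layers, sequences, hidden)
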
20 changes: 12 additions & 8 deletions rl_games/common/datasets.py
@@ -2,20 +2,23 @@
import copy
from torch.utils.data import Dataset


class PPODataset(Dataset):
def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_len):

def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_length):

self.is_rnn = is_rnn
self.seq_len = seq_len
self.seq_length = seq_length
self.batch_size = batch_size
self.minibatch_size = minibatch_size
self.device = device
self.length = self.batch_size // self.minibatch_size
self.is_discrete = is_discrete
self.is_continuous = not is_discrete
total_games = self.batch_size // self.seq_len
self.num_games_batch = self.minibatch_size // self.seq_len
total_games = self.batch_size // self.seq_length
self.num_games_batch = self.minibatch_size // self.seq_length
self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device)
self.flat_indexes = torch.arange(total_games * self.seq_len, dtype=torch.long, device=self.device).reshape(total_games, self.seq_len)
self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length)

self.special_names = ['rnn_states']

@@ -34,9 +37,10 @@ def __len__(self):
def _get_item_rnn(self, idx):
gstart = idx * self.num_games_batch
gend = (idx + 1) * self.num_games_batch
start = gstart * self.seq_len
end = gend * self.seq_len
self.last_range = (start, end)
start = gstart * self.seq_length
end = gend * self.seq_length
self.last_range = (start, end)

input_dict = {}
for k,v in self.values_dict.items():
if k not in self.special_names:
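To illustrate the renamed field in PPODataset, here is a worked pass through the _get_item_rnn indexing with illustrative sizes (batch_size, minibatch_size and seq_length below are assumptions, not shipped defaults).

batch_size, minibatch_size, seq_length = 1024, 256, 8

total_games = batch_size // seq_length            # 128 sequences in the full batch
num_games_batch = minibatch_size // seq_length    # 32 sequences per minibatch

idx = 1                                           # second minibatch
gstart, gend = idx * num_games_batch, (idx + 1) * num_games_batch    # sequences 32..64
start, end = gstart * seq_length, gend * seq_length                  # flat transitions 256..512
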
3 changes: 2 additions & 1 deletion rl_games/common/vecenv.py
@@ -1,4 +1,3 @@
import ray
from rl_games.common.ivecenv import IVecEnv
from rl_games.common.env_configurations import configurations
from rl_games.common.tr_helpers import dicts_to_dict_with_arrays
Expand Down Expand Up @@ -102,6 +101,8 @@ def __init__(self, config_name, num_actors, **kwargs):
self.num_actors = num_actors
self.use_torch = False
self.seed = kwargs.pop('seed', None)

import ray
self.remote_worker = ray.remote(RayWorker)
self.workers = [self.remote_worker.remote(self.config_name, kwargs) for i in range(self.num_actors)]

Expand Down