From c645d9a403a3b838c8e7cc217e8eec3274bdf740 Mon Sep 17 00:00:00 2001
From: ViktorM
Date: Tue, 26 Sep 2023 07:52:02 -0700
Subject: [PATCH] Readme update.

---
 README.md                               | 2 ++
 pyproject.toml                          | 1 -
 rl_games/algos_torch/network_builder.py | 4 ++--
 rl_games/common/a2c_common.py           | 4 ++++
 4 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 7393c8ff..f621ae19 100644
--- a/README.md
+++ b/README.md
@@ -306,6 +306,8 @@ Additional environment supported properties and functions
 * Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
 * Fixed SAC not loading weights properly.
 * Removed Ray dependency for use cases it's not required.
+* Added warning for using deprecated 'seq_len' instead of 'seq_length' in configs with RNN networks.
+
 1.6.0
diff --git a/pyproject.toml b/pyproject.toml
index bed33c78..e73c4c42 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,6 @@ tensorboardX = "^2.5"
 PyYAML = "^6.0"
 psutil = "^5.9.0"
 setproctitle = "^1.2.2"
-ray = "^1.11.0"
 opencv-python = "^4.5.5"
 wandb = "^0.12.11"
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index ce163e48..ce5651c5 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -323,7 +323,6 @@ def forward(self, obs_dict):
                 c_out = c_out.contiguous().view(c_out.size(0), -1)
 
                 if self.has_rnn:
-                    #seq_length = obs_dict['seq_length']
                     seq_length = obs_dict.get('seq_length', 1)
 
                     if not self.is_rnn_before_mlp:
@@ -360,9 +359,11 @@ def forward(self, obs_dict):
                         c_out = c_out.transpose(0,1)
                     a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1)
                     c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1)
+
                     if self.rnn_ln:
                         a_out = self.a_layer_norm(a_out)
                         c_out = self.c_layer_norm(c_out)
+
                     if type(a_states) is not tuple:
                         a_states = (a_states,)
                         c_states = (c_states,)
@@ -399,7 +400,6 @@ def forward(self, obs_dict):
                 out = out.flatten(1)
 
                 if self.has_rnn:
-                    #seq_length = obs_dict['seq_length']
                     seq_length = obs_dict.get('seq_length', 1)
 
                     out_in = out
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index 9c4f1981..63b90c07 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -204,10 +204,14 @@ def __init__(self, base_name, params):
         self.horizon_length = config['horizon_length']
 
         # seq_length is used only with rnn policy and value functions
+        if 'seq_len' in config:
+            print('WARNING: seq_len is deprecated, use seq_length instead')
+
         self.seq_length = self.config.get('seq_length', 4)
         print('seq_length:', self.seq_length)
         self.bptt_len = self.config.get('bptt_length', self.seq_length) # not used right now. Didn't show that it is usefull
         self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True)
+
         self.normalize_advantage = config['normalize_advantage']
         self.normalize_rms_advantage = config.get('normalize_rms_advantage', False)
         self.normalize_input = self.config['normalize_input']
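
Note on the config handling above: the patch only warns when the deprecated 'seq_len' key is present; self.seq_length is still read exclusively from 'seq_length' (defaulting to 4), so a config that sets only 'seq_len' keeps the default. A minimal sketch of the relevant fields as they look after the YAML config is loaded into a dict (keys other than seq_length/seq_len and all values are illustrative, not taken from any specific rl_games config):

    # Sketch of the config dict consumed by A2CBase.__init__ in a2c_common.py.
    # Only the seq_length/seq_len handling matters here; other keys are examples.
    config = {
        'horizon_length': 16,
        'seq_length': 4,        # new key, read via config.get('seq_length', 4)
        # 'seq_len': 4,         # deprecated key: triggers the warning added above
        #                       # but is not used to set self.seq_length
        'normalize_advantage': True,
        'normalize_input': True,
    }

    seq_length = config.get('seq_length', 4)  # mirrors the lookup in the patch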