Skip to content

Commit

Permalink
Readme update.
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM committed Sep 26, 2023
1 parent 94e0ce1 commit c645d9a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ Additional environment supported properties and functions
* Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
* Fixed SAC not loading weights properly.
* Removed Ray dependency for use cases it's not required.
* Added warning for using deprecated 'seq_len' instead of 'seq_length' in configs with RNN networks.


1.6.0

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ tensorboardX = "^2.5"
PyYAML = "^6.0"
psutil = "^5.9.0"
setproctitle = "^1.2.2"
ray = "^1.11.0"
opencv-python = "^4.5.5"
wandb = "^0.12.11"

Expand Down
4 changes: 2 additions & 2 deletions rl_games/algos_torch/network_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,6 @@ def forward(self, obs_dict):
c_out = c_out.contiguous().view(c_out.size(0), -1)

if self.has_rnn:
#seq_length = obs_dict['seq_length']
seq_length = obs_dict.get('seq_length', 1)

if not self.is_rnn_before_mlp:
Expand Down Expand Up @@ -360,9 +359,11 @@ def forward(self, obs_dict):
c_out = c_out.transpose(0,1)
a_out = a_out.contiguous().reshape(a_out.size()[0] * a_out.size()[1], -1)
c_out = c_out.contiguous().reshape(c_out.size()[0] * c_out.size()[1], -1)

if self.rnn_ln:
a_out = self.a_layer_norm(a_out)
c_out = self.c_layer_norm(c_out)

if type(a_states) is not tuple:
a_states = (a_states,)
c_states = (c_states,)
Expand Down Expand Up @@ -399,7 +400,6 @@ def forward(self, obs_dict):
out = out.flatten(1)

if self.has_rnn:
#seq_length = obs_dict['seq_length']
seq_length = obs_dict.get('seq_length', 1)

out_in = out
Expand Down
4 changes: 4 additions & 0 deletions rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,14 @@ def __init__(self, base_name, params):
self.horizon_length = config['horizon_length']

# seq_length is used only with rnn policy and value functions
if 'seq_len' in config:
print('WARNING: seq_len is deprecated, use seq_length instead')

self.seq_length = self.config.get('seq_length', 4)
print('seq_length:', self.seq_length)
self.bptt_len = self.config.get('bptt_length', self.seq_length) # not used right now. Didn't show that it is usefull
self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True)

self.normalize_advantage = config['normalize_advantage']
self.normalize_rms_advantage = config.get('normalize_rms_advantage', False)
self.normalize_input = self.config['normalize_input']
Expand Down

0 comments on commit c645d9a

Please sign in to comment.