
Commit

Fixes.
ViktorM committed Sep 25, 2023
1 parent 06756a0 commit 872e767
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -305,6 +305,7 @@ Additional environment supported properties and functions
 * Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
 * Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
 * Fixed SAC not loading weights properly.
+* Removed Ray dependency for use cases where it's not required.
 
 1.6.0
 
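The Ray entry above likely corresponds to importing Ray only where it is actually needed. A minimal sketch of that optional-import pattern, assuming a hypothetical helper name (illustrative only, not rl_games' actual code):

    # Hypothetical sketch of an optional-import guard for Ray; the helper
    # name `require_ray` is illustrative, not from the rl_games codebase.
    def require_ray():
        try:
            import ray
        except ImportError as exc:
            raise ImportError(
                "Ray is only needed for distributed/PBT features; "
                "install it with `pip install ray` to use them."
            ) from exc
        return ray

With this shape, plain single-process training never touches the import, so Ray can be dropped from the base requirements.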
6 changes: 3 additions & 3 deletions rl_games/algos_torch/network_builder.py
@@ -323,7 +323,7 @@ def forward(self, obs_dict):
                 c_out = c_out.contiguous().view(c_out.size(0), -1)
 
                 if self.has_rnn:
-                    seq_length = obs_dict.get['seq_length']
+                    seq_length = obs_dict['seq_length']
 
                     if not self.is_rnn_before_mlp:
                         a_out_in = a_out
@@ -398,7 +398,7 @@ def forward(self, obs_dict):
             out = out.flatten(1)
 
             if self.has_rnn:
-                seq_length = obs_dict.get['seq_length']
+                seq_length = obs_dict['seq_length']
 
                 out_in = out
                 if not self.is_rnn_before_mlp:
@@ -712,7 +712,7 @@ def forward(self, obs_dict):
             out = self.flatten_act(out)
 
             if self.has_rnn:
-                seq_length = obs_dict.get['seq_length']
+                seq_length = obs_dict['seq_length']
 
                 out_in = out
                 if not self.is_rnn_before_mlp:
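The three hunks above fix the same bug: dict.get is a method, so subscripting it with square brackets raises a TypeError at runtime. A quick standalone demonstration:

    obs_dict = {'seq_length': 4}

    # obs_dict.get['seq_length']   # TypeError: 'builtin_function_or_method'
    #                              # object is not subscriptable
    print(obs_dict.get('seq_length'))  # 4; returns None if the key is absent
    print(obs_dict['seq_length'])      # 4; raises KeyError if the key is absent

The commit uses direct indexing, so a missing 'seq_length' fails loudly with a KeyError rather than silently propagating None into the RNN reshape.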
2 changes: 1 addition & 1 deletion rl_games/common/a2c_common.py
@@ -466,7 +466,7 @@ def init_tensors(self):
             self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
 
             total_agents = self.num_agents * self.num_actors
-            num_seqs = self.horizon_length // self.seq_len
+            num_seqs = self.horizon_length // self.seq_length
             assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
             self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

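The renamed attribute feeds the RNN sequence bookkeeping: the rollout horizon is cut into fixed-length sequences, and the assert checks that every minibatch contains a whole number of them. A worked example with illustrative numbers (not values from any shipped rl_games config):

    # Illustrative numbers only, not defaults from a shipped config.
    horizon_length = 32
    seq_length = 8
    num_agents, num_actors = 1, 16
    num_minibatches = 4

    total_agents = num_agents * num_actors          # 16
    num_seqs = horizon_length // seq_length         # 4 sequences per agent
    minibatch_size = horizon_length * total_agents // num_minibatches  # 128
    assert minibatch_size % seq_length == 0         # 128 % 8 == 0: each
                                                    # minibatch holds whole sequences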
