diff --git a/README.md b/README.md
index 6410f09d..7393c8ff 100644
--- a/README.md
+++ b/README.md
@@ -305,6 +305,7 @@ Additional environment supported properties and functions
 * Added evaluation feature for inferencing during training. Checkpoints from training process can be automatically picked up and updated in the inferencing process when enabled.
 * Added get/set API for runtime update of rl training parameters. Thanks to @ArthurAllshire for the initial version of fast PBT code.
 * Fixed SAC not loading weights properly.
+* Removed Ray dependency for use cases where it's not required.
 
 1.6.0
 
diff --git a/rl_games/algos_torch/network_builder.py b/rl_games/algos_torch/network_builder.py
index 26434027..73447607 100644
--- a/rl_games/algos_torch/network_builder.py
+++ b/rl_games/algos_torch/network_builder.py
@@ -323,7 +323,7 @@ def forward(self, obs_dict):
             c_out = c_out.contiguous().view(c_out.size(0), -1)
 
             if self.has_rnn:
-                seq_length = obs_dict.get['seq_length']
+                seq_length = obs_dict['seq_length']
 
                 if not self.is_rnn_before_mlp:
                     a_out_in = a_out
@@ -398,7 +398,7 @@ def forward(self, obs_dict):
             out = out.flatten(1)
 
             if self.has_rnn:
-                seq_length = obs_dict.get['seq_length']
+                seq_length = obs_dict['seq_length']
 
                 out_in = out
                 if not self.is_rnn_before_mlp:
@@ -712,7 +712,7 @@ def forward(self, obs_dict):
             out = self.flatten_act(out)
 
             if self.has_rnn:
-                seq_length = obs_dict.get['seq_length']
+                seq_length = obs_dict['seq_length']
 
                 out_in = out
                 if not self.is_rnn_before_mlp:
diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py
index a3c53df7..b1cda019 100644
--- a/rl_games/common/a2c_common.py
+++ b/rl_games/common/a2c_common.py
@@ -466,7 +466,7 @@ def init_tensors(self):
             self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]
 
             total_agents = self.num_agents * self.num_actors
-            num_seqs = self.horizon_length // self.seq_len
+            num_seqs = self.horizon_length // self.seq_length
             assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_length == 0)
             self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]
 