
Commit 06756a0: More seq_length work.
ViktorM committed Sep 25, 2023 (parent: 61e998f)
Showing 3 changed files with 11 additions and 7 deletions.
rl_games/algos_torch/central_value.py (1 addition, 1 deletion)
@@ -12,7 +12,7 @@
 class CentralValueTrain(nn.Module):
 
     def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_length, num_actors, num_actions,
-                 seq_len, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
+                 seq_length, normalize_value, network, config, writter, max_epochs, multi_gpu, zero_rnn_on_done):
         nn.Module.__init__(self)
 
         self.ppo_device = ppo_device
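Since the constructor now takes seq_length, any caller still passing the old seq_len= keyword will fail with a TypeError rather than silently misbehave. A minimal toy illustration of that failure mode (a stand-in class, not the real CentralValueTrain):

    # Toy stand-in, only to show the effect of renaming a keyword argument.
    class Toy:
        def __init__(self, seq_length):  # renamed from seq_len
            self.seq_length = seq_length

    Toy(seq_length=8)   # fine
    # Toy(seq_len=8)    # TypeError: unexpected keyword argument 'seq_len'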
rl_games/common/a2c_common.py (4 additions, 1 deletion)
@@ -202,7 +202,10 @@ def __init__(self, base_name, params):
         self.rewards_shaper = config['reward_shaper']
         self.num_agents = self.env_info.get('agents', 1)
         self.horizon_length = config['horizon_length']
+
+        # seq_length is used only with RNN policy and value functions
         self.seq_length = self.config.get('seq_length', 4)
+        print('seq_length:', self.seq_length)
         self.bptt_len = self.config.get('bptt_length', self.seq_length)  # not used right now; it has not been shown to be useful
         self.zero_rnn_on_done = self.config.get('zero_rnn_on_done', True)
         self.normalize_advantage = config['normalize_advantage']
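The lookups above are plain dict reads with defaults: a config that omits seq_length trains RNNs with windows of 4 transitions, and bptt_length falls back to seq_length. A standalone sketch of that fallback chain (config contents are illustrative):

    # Illustrative config fragment; keys mirror the get() calls above.
    config = {'horizon_length': 32}                   # 'seq_length' omitted on purpose
    seq_length = config.get('seq_length', 4)          # falls back to the default of 4
    bptt_len = config.get('bptt_length', seq_length)  # defaults to seq_length itself
    print(seq_length, bptt_len)                       # 4 4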
@@ -794,7 +797,7 @@ def play_steps_rnn(self):
         for n in range(self.horizon_length):
             if n % self.seq_length == 0:
                 for s, mb_s in zip(self.rnn_states, mb_rnn_states):
-                    mb_s[n // self.seq_len,:,:,:] = s
+                    mb_s[n // self.seq_length,:,:,:] = s
 
             if self.has_central_value:
                 self.central_value_net.pre_step_rnn(n)
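The fixed line is the rollout's RNN-state bookkeeping: at every step n that starts a new window (n % seq_length == 0) the current hidden state is stashed into buffer slot n // seq_length, so the window test and the slot index must read the same attribute, which is what the rename enforces. A standalone sketch of that bookkeeping, with made-up shapes:

    import torch

    horizon_length, seq_length = 32, 8
    num_layers, num_envs, hidden = 1, 4, 16

    # One slot per seq_length-long window: horizon_length // seq_length slots in total.
    mb_rnn_states = torch.zeros(horizon_length // seq_length, num_layers, num_envs, hidden)
    rnn_state = torch.zeros(num_layers, num_envs, hidden)

    for n in range(horizon_length):
        if n % seq_length == 0:
            mb_rnn_states[n // seq_length] = rnn_state  # stash the state that starts this window
        # ... an env step would update rnn_state here ...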
rl_games/common/datasets.py (6 additions, 5 deletions)
@@ -15,9 +15,9 @@ def __init__(self, batch_size, minibatch_size, is_discrete, is_rnn, device, seq_
         self.is_discrete = is_discrete
         self.is_continuous = not is_discrete
         total_games = self.batch_size // self.seq_length
-        self.num_games_batch = self.minibatch_size // self.seq_len
+        self.num_games_batch = self.minibatch_size // self.seq_length
         self.game_indexes = torch.arange(total_games, dtype=torch.long, device=self.device)
-        self.flat_indexes = torch.arange(total_games * self.seq_len, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length)
+        self.flat_indexes = torch.arange(total_games * self.seq_length, dtype=torch.long, device=self.device).reshape(total_games, self.seq_length)
 
         self.special_names = ['rnn_states']
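Both fixed lines feed the dataset's index bookkeeping: the flat batch is viewed as total_games sequences of seq_length transitions each, and flat_indexes records which flat positions belong to which sequence. A standalone sketch with small made-up sizes:

    import torch

    batch_size, minibatch_size, seq_length = 24, 12, 4
    total_games = batch_size // seq_length           # 6 sequences ("games") in the batch
    num_games_batch = minibatch_size // seq_length   # 3 sequences per minibatch
    flat_indexes = torch.arange(total_games * seq_length).reshape(total_games, seq_length)
    print(flat_indexes[0])                           # tensor([0, 1, 2, 3]): transitions of sequence 0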

@@ -36,9 +36,10 @@ def __len__(self):
     def _get_item_rnn(self, idx):
         gstart = idx * self.num_games_batch
         gend = (idx + 1) * self.num_games_batch
-        start = gstart * self.seq_len
-        end = gend * self.seq_len
-        self.last_range = (start, end)
+        start = gstart * self.seq_length
+        end = gend * self.seq_length
+        self.last_range = (start, end)
+
         input_dict = {}
         for k,v in self.values_dict.items():
             if k not in self.special_names:
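_get_item_rnn then slices whole sequences per minibatch: minibatch idx owns sequences [gstart, gend) and therefore transitions [gstart * seq_length, gend * seq_length), so mixing seq_len and seq_length here would have misaligned the slice. A small sketch of that arithmetic, reusing the made-up sizes above:

    num_games_batch, seq_length = 3, 4       # same toy sizes as the previous sketch

    idx = 1                                  # second minibatch
    gstart = idx * num_games_batch           # sequences 3..5
    gend = (idx + 1) * num_games_batch
    start = gstart * seq_length              # 12
    end = gend * seq_length                  # 24
    print((start, end))                      # (12, 24): v[start:end] picks whole sequences only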
