added mini batch size per env (#159)
Denys88 authored May 21, 2022
1 parent f0801ab commit a320613
Showing 2 changed files with 9 additions and 5 deletions.
8 changes: 5 additions & 3 deletions rl_games/algos_torch/central_value.py
@@ -45,8 +45,10 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
         self.scheduler = schedulers.IdentityScheduler()
 
         self.mini_epoch = config['mini_epochs']
-        self.mini_batch = config['minibatch_size']
-        self.num_minibatches = self.horizon_length * self.num_actors // self.mini_batch
+        assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
+        self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
+        self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
+        self.num_minibatches = self.horizon_length * self.num_actors // self.minibatch_size
         self.clip_value = config['clip_value']
 
         self.writter = writter
@@ -71,7 +73,7 @@ def __init__(self, state_shape, value_size, ppo_device, num_agents, horizon_leng
         assert ((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0)
         self.mb_rnn_states = [ torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype=torch.float32, device=self.ppo_device) for s in self.rnn_states]
 
-        self.dataset = datasets.PPODataset(self.batch_size, self.mini_batch, True, self.is_rnn, self.ppo_device, self.seq_len)
+        self.dataset = datasets.PPODataset(self.batch_size, self.minibatch_size, True, self.is_rnn, self.ppo_device, self.seq_len)
 
     def update_lr(self, lr):
         if self.multi_gpu:
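The resolution logic added here (and repeated in a2c_common.py below) is the same three lines, so it is easy to sanity-check in isolation. The following is a minimal standalone sketch of that logic; the resolve_minibatch_size helper and the numbers are hypothetical, written only to illustrate the precedence: an explicit minibatch_size wins, otherwise minibatch_size_per_env is scaled by the number of actors.

# Minimal sketch of the resolution added above; helper name and values are hypothetical.
def resolve_minibatch_size(config, num_actors):
    # Mirrors the new assert: at least one of the two keys must be configured.
    assert ('minibatch_size_per_env' in config) or ('minibatch_size' in config)
    minibatch_size_per_env = config.get('minibatch_size_per_env', 0)
    # An explicit 'minibatch_size' takes precedence; otherwise derive it per env.
    return config.get('minibatch_size', num_actors * minibatch_size_per_env)

print(resolve_minibatch_size({'minibatch_size_per_env': 8}, num_actors=64))   # 512
print(resolve_minibatch_size({'minibatch_size': 1024}, num_actors=64))        # 1024
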
6 changes: 4 additions & 2 deletions rl_games/common/a2c_common.py
@@ -170,14 +170,16 @@ def __init__(self, base_name, params):
         self.tau = self.config['tau']
 
         self.games_to_track = self.config.get('games_to_track', 100)
-        print(self.ppo_device)
+        print('current training device:', self.ppo_device)
         self.game_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
         self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
         self.obs = None
         self.games_num = self.config['minibatch_size'] // self.seq_len # it is used only for current rnn implementation
         self.batch_size = self.horizon_length * self.num_actors * self.num_agents
         self.batch_size_envs = self.horizon_length * self.num_actors
-        self.minibatch_size = self.config['minibatch_size']
+        assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
+        self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
+        self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
         self.mini_epochs_num = self.config['mini_epochs']
         self.num_minibatches = self.batch_size // self.minibatch_size
         assert(self.batch_size % self.minibatch_size == 0)
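One practical consequence of the per-env form: since batch_size = horizon_length * num_actors * num_agents and the derived minibatch_size = num_actors * minibatch_size_per_env, the num_actors factor cancels, so the assert on batch_size % minibatch_size holds for any actor count as long as horizon_length * num_agents is divisible by minibatch_size_per_env. A small illustrative check, with all numbers hypothetical:

# Hypothetical numbers; shows the divisibility assert is independent of num_actors.
horizon_length, num_agents, minibatch_size_per_env = 32, 1, 8

for num_actors in (64, 512, 4096):
    batch_size = horizon_length * num_actors * num_agents
    minibatch_size = num_actors * minibatch_size_per_env
    assert batch_size % minibatch_size == 0   # holds regardless of num_actors
    print(num_actors, batch_size // minibatch_size)  # num_minibatches is always 4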
