Skip to content

Commit

Permalink
Fixed applying minibatch_size_per_env (#287)
Browse files Browse the repository at this point in the history
  • Loading branch information
ViktorM authored Jun 13, 2024
1 parent 684df64 commit 66970f8
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion rl_games/common/a2c_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ def __init__(self, base_name, params):
self.game_shaped_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
self.obs = None
self.games_num = self.config['minibatch_size'] // self.seq_length # it is used only for current rnn implementation

self.batch_size = self.horizon_length * self.num_actors * self.num_agents
self.batch_size_envs = self.horizon_length * self.num_actors
Expand All @@ -245,6 +244,16 @@ def __init__(self, base_name, params):
self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)

# either minibatch_size_per_env or minibatch_size should be present in a config
# if both are present, minibatch_size is used
# otherwise minibatch_size_per_env is used minibatch_size_per_env is used to calculate minibatch_size
self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)

assert(self.minibatch_size > 0)

self.games_num = self.minibatch_size // self.seq_length # it is used only for current rnn implementation

self.num_minibatches = self.batch_size // self.minibatch_size
assert(self.batch_size % self.minibatch_size == 0)

Expand Down

0 comments on commit 66970f8

Please sign in to comment.