Merge pull request #327 from kengz/ppo-minibatch
Add PPO minibatch sampling
kengz authored May 4, 2019
2 parents 10fa7d1 + a3d9e32 commit 601d9bb
Showing 3 changed files with 50 additions and 30 deletions.
24 changes: 17 additions & 7 deletions slm_lab/agent/algorithm/ppo.py
@@ -51,6 +51,7 @@ class PPO(ActorCritic):
             "start_step": 100,
             "end_step": 5000,
         },
+        "minibatch_size": 256,
         "training_frequency": 1,
         "training_epoch": 8,
         "normalize_state": true
@@ -72,6 +73,7 @@ def init_algorithm_params(self):
             action_policy='default',
             explore_var_spec=None,
             entropy_coef_spec=None,
+            minibatch_size=8,
             val_loss_coef=1.0,
         ))
         util.set_attr(self, self.algorithm_spec, [
@@ -84,6 +86,7 @@ def init_algorithm_params(self):
             'clip_eps_spec',
             'entropy_coef_spec',
             'val_loss_coef',
+            'minibatch_size',
             'training_frequency',  # horizon
             'training_epoch',
             'normalize_state',
@@ -166,15 +169,22 @@ def train(self):
             batch = self.sample()
             _pdparams, v_preds = self.calc_pdparam_v(batch)
             advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
-            batch['advs'] = advs
-            batch['v_targets'] = v_targets
+            # piggy back on batch, but remember to not pack or unpack
+            batch['advs'], batch['v_targets'] = advs, v_targets
+            if self.body.env.is_venv:  # unpack if venv for minibatch sampling
+                for k, v in batch.items():
+                    if k not in ('advs', 'v_targets'):
+                        batch[k] = math_util.venv_unpack(v)
             total_loss = torch.tensor(0.0)
             for _ in range(self.training_epoch):
-                minibatch = batch  # TODO sample minibatch from batch with size < length of batch
-                advs = batch['advs']
-                v_targets = batch['v_targets']
-                pdparams, v_preds = self.calc_pdparam_v(batch)
-                policy_loss = self.calc_policy_loss(batch, pdparams, advs)  # from actor
+                minibatch = util.sample_minibatch(batch, self.minibatch_size)
+                if self.body.env.is_venv:  # re-pack to restore proper shape
+                    for k, v in minibatch.items():
+                        if k not in ('advs', 'v_targets'):
+                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
+                advs, v_targets = minibatch['advs'], minibatch['v_targets']
+                pdparams, v_preds = self.calc_pdparam_v(minibatch)
+                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                 val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                 if self.shared:  # shared network
                     loss = policy_loss + val_loss
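Side note on the venv handling above: util.sample_minibatch indexes along the first (batch) dimension, so vector-env data must be flattened from (horizon, num_envs, ...) before sampling and re-packed afterwards. math_util.venv_unpack and math_util.venv_pack are not part of this diff; the sketch below shows the reshape they are assumed to perform.

import torch

def venv_unpack(v):
    # assumed: merge (seq_len, num_envs, *shape) into (seq_len * num_envs, *shape)
    s = v.shape
    return v.reshape(s[0] * s[1], *s[2:])

def venv_pack(v, num_envs):
    # assumed inverse: split the batch dim back into (seq_len, num_envs, *shape)
    s = v.shape
    return v.reshape(-1, num_envs, *s[1:])

states = torch.zeros(128, 8, 4)  # 128 steps from 8 envs of 4-dim states
flat = venv_unpack(states)       # (1024, 4): one row per experience, safe to index
assert venv_pack(flat, 8).shape == states.shape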
51 changes: 30 additions & 21 deletions slm_lab/lib/util.py
@@ -533,6 +533,15 @@ def s_get(cls, attr_path):
     return res


+def sample_minibatch(batch, mb_size):
+    '''Sample a minibatch within a batch that is produced by to_torch_batch()'''
+    size = len(batch['rewards'])
+    assert mb_size < size, f'Minibatch size {mb_size} must be < batch size {size}'
+    minibatch_idxs = np.random.randint(size, size=mb_size)
+    minibatch = {k: v[minibatch_idxs] for k, v in batch.items()}
+    return minibatch
+
+
 def self_desc(cls):
     '''Method to get self description, used at init.'''
     desc_list = [f'{get_class_name(cls)}:']
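A quick usage sketch of the new helper (toy tensors; the key names follow what to_torch_batch() produces). Note that np.random.randint draws indices with replacement.

import torch
from slm_lab.lib import util

batch = {
    'states': torch.zeros(1024, 4),
    'actions': torch.zeros(1024),
    'rewards': torch.zeros(1024),
}
minibatch = util.sample_minibatch(batch, 256)  # 256 < 1024 satisfies the assert
assert minibatch['states'].shape == (256, 4)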
@@ -575,6 +584,27 @@ def set_attr(obj, attr_dict, keys=None):
     return obj


+def set_cuda_id(spec, info_space):
+    '''Use trial and session id to hash and modulo cuda device count for a cuda_id to maximize device usage. Sets the net_spec for the base Net class to pick up.'''
+    # Don't trigger any cuda call if not using GPU. Otherwise will break multiprocessing on machines with CUDA.
+    # see issues https://github.com/pytorch/pytorch/issues/334 https://github.com/pytorch/pytorch/issues/3491 https://github.com/pytorch/pytorch/issues/9996
+    for agent_spec in spec['agent']:
+        if not agent_spec['net'].get('gpu'):
+            return
+    trial_idx = info_space.get('trial') or 0
+    session_idx = info_space.get('session') or 0
+    job_idx = trial_idx * spec['meta']['max_session'] + session_idx
+    job_idx += int(os.environ.get('CUDA_ID_OFFSET', 0))
+    device_count = torch.cuda.device_count()
+    if device_count == 0:
+        cuda_id = None
+    else:
+        cuda_id = job_idx % device_count
+
+    for agent_spec in spec['agent']:
+        agent_spec['net']['cuda_id'] = cuda_id
+
+
 def set_logger(spec, info_space, logger, unit=None):
     '''Set the logger for a lab unit given its spec and info_space'''
     os.environ['PREPATH'] = get_prepath(spec, info_space, unit=unit)
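A hypothetical worked example of the round-robin device assignment (numbers invented for illustration):

# suppose spec['meta']['max_session'] == 4 and torch.cuda.device_count() == 2
trial_idx, session_idx = 3, 1
job_idx = trial_idx * 4 + session_idx  # 13
cuda_id = job_idx % 2                  # 1, so this session lands on GPU 1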
@@ -662,27 +692,6 @@ def to_torch_batch(batch, device, is_episodic):
     return batch


-def set_cuda_id(spec, info_space):
-    '''Use trial and session id to hash and modulo cuda device count for a cuda_id to maximize device usage. Sets the net_spec for the base Net class to pick up.'''
-    # Don't trigger any cuda call if not using GPU. Otherwise will break multiprocessing on machines with CUDA.
-    # see issues https://github.com/pytorch/pytorch/issues/334 https://github.com/pytorch/pytorch/issues/3491 https://github.com/pytorch/pytorch/issues/9996
-    for agent_spec in spec['agent']:
-        if not agent_spec['net'].get('gpu'):
-            return
-    trial_idx = info_space.get('trial') or 0
-    session_idx = info_space.get('session') or 0
-    job_idx = trial_idx * spec['meta']['max_session'] + session_idx
-    job_idx += int(os.environ.get('CUDA_ID_OFFSET', 0))
-    device_count = torch.cuda.device_count()
-    if device_count == 0:
-        cuda_id = None
-    else:
-        cuda_id = job_idx % device_count
-
-    for agent_spec in spec['agent']:
-        agent_spec['net']['cuda_id'] = cuda_id
-
-
 def write(data, data_path):
     '''
     Universal data writing method with smart data parsing
5 changes: 3 additions & 2 deletions slm_lab/spec/experimental/ppo_pong.json
@@ -24,7 +24,8 @@
           "end_step": 0
         },
         "val_loss_coef": 0.5,
-        "training_frequency": 32,
+        "minibatch_size": 256,
+        "training_frequency": 128,
         "training_epoch": 4,
         "normalize_state": false
       },
@@ -68,7 +69,7 @@
     }],
     "env": [{
       "name": "PongNoFrameskip-v4",
-      "num_envs": 16,
+      "num_envs": 8,
       "max_t": null,
       "max_tick": 1e7
     }],
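For reference, the retuned values fit together as follows, assuming training_frequency is the per-env rollout horizon (per the '# horizon' comment in ppo.py):

num_envs = 8
training_frequency = 128                     # rollout horizon per env
batch_size = num_envs * training_frequency   # 1024 experiences per PPO update
minibatch_size = 256                         # must be < batch_size; each of the 4 epochs samples 256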