Add PPO minibatch sampling #327

Merged: 4 commits, May 4, 2019
24 changes: 17 additions & 7 deletions — slm_lab/agent/algorithm/ppo.py

@@ -51,6 +51,7 @@ class PPO(ActorCritic):
         "start_step": 100,
         "end_step": 5000,
     },
+    "minibatch_size": 256,
     "training_frequency": 1,
     "training_epoch": 8,
     "normalize_state": true
@@ -72,6 +73,7 @@ def init_algorithm_params(self):
             action_policy='default',
             explore_var_spec=None,
             entropy_coef_spec=None,
+            minibatch_size=8,
             val_loss_coef=1.0,
         ))
         util.set_attr(self, self.algorithm_spec, [
@@ -84,6 +86,7 @@ def init_algorithm_params(self):
             'clip_eps_spec',
             'entropy_coef_spec',
             'val_loss_coef',
+            'minibatch_size',
             'training_frequency',  # horizon
             'training_epoch',
             'normalize_state',
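
The two hunks above follow the usual register-a-default-then-override pattern: minibatch_size defaults to 8, and a spec file can override it through util.set_attr. A minimal sketch of that pattern, with set_attr reduced to a hypothetical stand-in (the real util.set_attr may do more):

    # hedged sketch: stand-in for util.set_attr to show default-then-override
    class Algo:
        pass

    def set_attr(obj, attr_dict, keys=None):
        # copy selected keys from attr_dict onto obj as attributes
        for k, v in attr_dict.items():
            if keys is None or k in keys:
                setattr(obj, k, v)
        return obj

    algo = Algo()
    set_attr(algo, dict(minibatch_size=8))                            # default
    set_attr(algo, {'minibatch_size': 256}, keys=['minibatch_size'])  # spec override
    assert algo.minibatch_size == 256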
@@ -166,15 +169,22 @@ def train(self):
             batch = self.sample()
             _pdparams, v_preds = self.calc_pdparam_v(batch)
             advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
-            batch['advs'] = advs
-            batch['v_targets'] = v_targets
+            # piggyback on batch, but remember not to pack or unpack these
+            batch['advs'], batch['v_targets'] = advs, v_targets
+            if self.body.env.is_venv:  # unpack if venv for minibatch sampling
+                for k, v in batch.items():
+                    if k not in ('advs', 'v_targets'):
+                        batch[k] = math_util.venv_unpack(v)
             total_loss = torch.tensor(0.0)
             for _ in range(self.training_epoch):
-                minibatch = batch  # TODO sample minibatch from batch with size < length of batch
-                advs = batch['advs']
-                v_targets = batch['v_targets']
-                pdparams, v_preds = self.calc_pdparam_v(batch)
-                policy_loss = self.calc_policy_loss(batch, pdparams, advs)  # from actor
+                minibatch = util.sample_minibatch(batch, self.minibatch_size)
+                if self.body.env.is_venv:  # re-pack to restore proper shape
+                    for k, v in minibatch.items():
+                        if k not in ('advs', 'v_targets'):
+                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
+                advs, v_targets = minibatch['advs'], minibatch['v_targets']
+                pdparams, v_preds = self.calc_pdparam_v(minibatch)
+                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                 val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                 if self.shared:  # shared network
                     loss = policy_loss + val_loss
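
Since minibatch indices are drawn over flat transitions, vector-env data has to be unpacked before sampling and re-packed afterward so downstream shape assumptions still hold. A minimal sketch of the shape round-trip, assuming math_util.venv_unpack/venv_pack behave as reshape helpers over a (seq_len, num_envs, ...) layout:

    import torch

    def venv_unpack(v):
        # assumed behavior: (seq_len, num_envs, ...) -> (seq_len * num_envs, ...)
        return v.reshape(-1, *v.shape[2:])

    def venv_pack(v, num_envs):
        # assumed inverse: (seq_len * num_envs, ...) -> (seq_len, num_envs, ...)
        return v.reshape(-1, num_envs, *v.shape[1:])

    states = torch.zeros(32, 8, 4)  # 32 steps from 8 envs, 4-dim observations
    flat = venv_unpack(states)      # (256, 4): one row per transition, samplable
    assert venv_pack(flat, 8).shape == states.shape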
51 changes: 30 additions & 21 deletions — slm_lab/lib/util.py

@@ -533,6 +533,15 @@ def s_get(cls, attr_path):
     return res
 
 
+def sample_minibatch(batch, mb_size):
+    '''Sample a minibatch within a batch that is produced by to_torch_batch()'''
+    size = len(batch['rewards'])
+    assert mb_size < size, f'Minibatch size {mb_size} must be < batch size {size}'
+    minibatch_idxs = np.random.randint(size, size=mb_size)
+    minibatch = {k: v[minibatch_idxs] for k, v in batch.items()}
+    return minibatch
+
+
 def self_desc(cls):
     '''Method to get self description, used at init.'''
     desc_list = [f'{get_class_name(cls)}:']
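
A quick usage sketch against a toy batch (names illustrative). Note that np.random.randint draws indices with replacement, so a minibatch may repeat transitions:

    import numpy as np
    import torch

    # toy batch shaped like the output of util.to_torch_batch()
    batch = {
        'states': torch.arange(20.).reshape(10, 2),
        'rewards': torch.ones(10),
    }
    idxs = np.random.randint(len(batch['rewards']), size=4)  # with replacement
    minibatch = {k: v[idxs] for k, v in batch.items()}
    assert minibatch['states'].shape == (4, 2)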
@@ -575,6 +584,27 @@ def set_attr(obj, attr_dict, keys=None):
     return obj
 
 
+def set_cuda_id(spec, info_space):
+    '''Use trial and session id to hash and modulo cuda device count for a cuda_id to maximize device usage. Sets the net_spec for the base Net class to pick up.'''
+    # Don't trigger any cuda call if not using GPU. Otherwise will break multiprocessing on machines with CUDA.
+    # see issues https://github.com/pytorch/pytorch/issues/334 https://github.com/pytorch/pytorch/issues/3491 https://github.com/pytorch/pytorch/issues/9996
+    for agent_spec in spec['agent']:
+        if not agent_spec['net'].get('gpu'):
+            return
+    trial_idx = info_space.get('trial') or 0
+    session_idx = info_space.get('session') or 0
+    job_idx = trial_idx * spec['meta']['max_session'] + session_idx
+    job_idx += int(os.environ.get('CUDA_ID_OFFSET', 0))
+    device_count = torch.cuda.device_count()
+    if device_count == 0:
+        cuda_id = None
+    else:
+        cuda_id = job_idx % device_count
+
+    for agent_spec in spec['agent']:
+        agent_spec['net']['cuda_id'] = cuda_id
+
+
 def set_logger(spec, info_space, logger, unit=None):
     '''Set the logger for a lab unit give its spec and info_space'''
     os.environ['PREPATH'] = get_prepath(spec, info_space, unit=unit)
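
To make the round-robin assignment concrete, a worked example with illustrative numbers:

    # illustrative: trial 2, session 1, max_session 4, 4 visible GPUs, no offset
    trial_idx, session_idx, max_session = 2, 1, 4
    job_idx = trial_idx * max_session + session_idx  # 2 * 4 + 1 = 9
    device_count = 4
    cuda_id = job_idx % device_count  # 9 % 4 = 1
    # consecutive sessions thus cycle evenly across devices 0..3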
@@ -662,27 +692,6 @@ def to_torch_batch(batch, device, is_episodic):
     return batch
 
 
-def set_cuda_id(spec, info_space):
-    ...  # 21 lines removed: body unchanged, function relocated above next to set_attr
 
 
 def write(data, data_path):
     '''
     Universal data writing method with smart data parsing
5 changes: 3 additions & 2 deletions — slm_lab/spec/experimental/ppo_pong.json

@@ -24,7 +24,8 @@
           "end_step": 0
         },
         "val_loss_coef": 0.5,
-        "training_frequency": 32,
+        "minibatch_size": 256,
+        "training_frequency": 128,
         "training_epoch": 4,
         "normalize_state": false
       },
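
For a sense of scale under this spec: each train() call collects training_frequency steps from each of num_envs environments, and every epoch draws one minibatch from that pool. The arithmetic, assuming the batch flattens to training_frequency * num_envs transitions:

    training_frequency, num_envs = 128, 8
    batch_size = training_frequency * num_envs  # 1024 transitions per train() call
    minibatch_size, training_epoch = 256, 4
    # 4 gradient updates per train() call, each on a quarter of the batch
    print(batch_size // minibatch_size, training_epoch)  # 4 4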
@@ -68,7 +69,7 @@
     }],
     "env": [{
       "name": "PongNoFrameskip-v4",
-      "num_envs": 16,
+      "num_envs": 8,
       "max_t": null,
       "max_tick": 1e7
     }],