experiment: sampling policy gradient
YuriCat committed Sep 22, 2023
1 parent bc1d4bf commit 70ee321
Showing 2 changed files with 34 additions and 8 deletions.
4 changes: 3 additions & 1 deletion handyrl/envs/tictactoe.py
@@ -59,11 +59,13 @@ def __init__(self):
         self.head_p = Head((filters, 3, 3), 2, 9)
         self.head_v = Head((filters, 3, 3), 1, 1)
 
-    def forward(self, x, hidden=None):
+    def forward(self, x, hidden=None, actions=None):
         h = F.relu(self.conv(x))
         for block in self.blocks:
             h = F.relu(block(h))
         h_p = self.head_p(h)
+        if actions is not None:
+            h_p = h_p.gather(-1, actions)
         h_v = self.head_v(h)
 
         return {'policy': h_p, 'value': torch.tanh(h_v)}
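The gather call added above is what restricts the policy head to a subset of actions: h_p.gather(-1, actions) keeps only the logits at the listed indices, so the last dimension of the policy output shrinks from 9 to the number of indices passed in. A minimal sketch of that indexing with made-up shapes (not from the repository):

import torch

h_p = torch.randn(2, 9)                         # full policy logits for 2 states
actions = torch.tensor([[4, 0, 7], [1, 2, 8]])  # indices to keep for each state
subset = h_p.gather(-1, actions)                # shape (2, 3): logits at those indices only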
38 changes: 31 additions & 7 deletions handyrl/train.py
@@ -88,6 +88,15 @@ def replace_none(a, b):
 
         progress = np.arange(ep['start'], ep['end'], dtype=np.float32)[..., np.newaxis] / ep['total']
 
+        sampling_rate = 1
+        sampled_action_count = int((9 - 1) * sampling_rate)
+        all_actions = np.arange(0, 9 - 1, dtype=np.int64)
+        sampled_actions = np.array([np.random.choice(all_actions, size=(sampled_action_count), replace=False) for _ in range(np.prod(act.shape[:-1]))]).reshape(*act.shape[:-1], -1)
+        sampled_actions += (sampled_actions >= act).astype(np.int64)
+        cat_actions = np.concatenate([act, sampled_actions], -1)
+
+        amask = np.take_along_axis(amask, cat_actions, -1)
+
         # pad each array if step length is short
         batch_steps = args['burn_in_steps'] + args['forward_steps']
         if len(tmask) < batch_steps:
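The block above builds, for every batch entry, a list of action indices whose first slot is the taken action act and whose remaining slots are drawn uniformly without replacement from the other 8 actions. The += (sampled_actions >= act) line is the shift trick that makes this work: indices are drawn from 0..7 and every index greater than or equal to act is bumped by one, so the taken action can never be drawn twice. A small standalone illustration with made-up values (not from the repository):

import numpy as np

act = np.array([[4]])                                            # taken action, shape (1, 1)
others = np.random.choice(np.arange(8), size=8, replace=False)[np.newaxis, :]
others = others + (others >= act)                                # indices >= 4 move up by one, so 4 never reappears
cat = np.concatenate([act, others], -1)                          # slot 0 is the taken action, then the other 8

With sampling_rate = 1 all 8 remaining actions are kept, so cat_actions is a permutation of 0..8 that starts with act; lowering sampling_rate would drop some of them.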
@@ -104,12 +113,13 @@ def replace_none(a, b):
             omask = np.pad(omask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
             amask = np.pad(amask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=1e32)
             progress = np.pad(progress, [(pad_len_b, pad_len_a), (0, 0)], 'constant', constant_values=1)
+            cat_actions = np.pad(cat_actions, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
 
         obss.append(obs)
-        datum.append((prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress))
+        datum.append((prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress, cat_actions))
 
     obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o)))
-    prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = [to_torch(np.array(val)) for val in zip(*datum)]
+    prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress, cat_actions = [to_torch(np.array(val)) for val in zip(*datum)]
 
     return {
         'observation': obs,
@@ -121,6 +131,7 @@ def replace_none(a, b):
         'turn_mask': tmask, 'observation_mask': omask,
         'action_mask': amask,
         'progress': progress,
+        'sampled_actions': cat_actions,
     }
 
 
@@ -142,7 +153,8 @@ def forward_prediction(model, hidden, batch, args):
     if hidden is None:
         # feed-forward neural network
         obs = map_r(observations, lambda o: o.flatten(0, 2))  # (..., B * T * P or 1, ...)
-        outputs = model(obs, None)
+        sampled_actions = batch['sampled_actions'].flatten(0, 2)
+        outputs = model(obs, None, sampled_actions)
         outputs = map_r(outputs, lambda o: o.unflatten(0, batch_shape))  # (..., B, T, P or 1, ...)
     else:
         # sequential computation with RNN
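Since the feed-forward path collapses the leading (batch, time, player) dimensions of the observations before calling the model, the sampled action indices have to be collapsed the same way so their rows stay aligned with the observations, and the outputs are restored afterwards with unflatten. A shape-only sketch under assumed sizes (not from the repository):

import torch

B, T, P, K = 2, 4, 1, 9                                # assumed batch, steps, players, listed actions
sampled_actions = torch.zeros(B, T, P, K, dtype=torch.int64)
flat = sampled_actions.flatten(0, 2)                   # (B*T*P, K) = (8, 9), row-aligned with the flattened obs
restored = flat.unflatten(0, (B, T, P))                # back to (2, 4, 1, 9), mirroring the unflatten of the outputs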
@@ -199,7 +211,10 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, batch, args):
     losses = {}
     dcnt = tmasks.sum().item()
 
-    losses['p'] = (-log_selected_policies * total_advantages).mul(tmasks).sum()
+    p_selected_loss = -outputs['policy'][:, :, :, :1].mul(total_advantages).mul(tmasks).sum()
+    p_prob_loss = outputs['policy'].mul(F.softmax(outputs['policy'].detach(), -1) * targets['mod_ratio']).mul(total_advantages).mul(tmasks).sum()
+    losses['p'] = p_selected_loss + p_prob_loss
+
     if 'value' in outputs:
         losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2
     if 'return' in outputs:
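The rewritten policy loss splits into a term on the taken action's raw logit (slot 0 after the concatenation in make_batch) and a term over all listed actions weighted by their detached softmax probability times mod_ratio. Because the softmax is detached, the second term acts as the usual baseline: in the full-sampling case where mod_ratio is 1, the sum has exactly the same gradient as the old -log_softmax(policy).gather(-1, act) * advantage loss. A small self-contained check of that equivalence with toy numbers (not from the repository):

import torch
import torch.nn.functional as F

z = torch.randn(9, requires_grad=True)          # logits for all 9 actions
a, adv = 4, 1.7                                  # taken action and its advantage (made up)

loss_ref = -F.log_softmax(z, -1)[a] * adv        # original formulation
g_ref, = torch.autograd.grad(loss_ref, z)

loss_new = (-z[a] + (z * F.softmax(z, -1).detach()).sum()) * adv   # decomposed form with mod_ratio = 1
g_new, = torch.autograd.grad(loss_new, z)

print(torch.allclose(g_ref, g_new, atol=1e-6))   # True: the two losses give identical policy gradients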
@@ -229,7 +244,17 @@ def compute_loss(batch, model, hidden, args):
     clip_rho_threshold, clip_c_threshold = 1.0, 1.0
 
     log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
-    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks
+
+    targets = {}
+
+    sampled_count = outputs['policy'].size(-1)
+    total_count = 9
+    naive_selected_t_policies = F.softmax(outputs['policy'], dim=-1)
+    estimated_sampled_action_b_prob = (1 - batch['selected_prob']) * ((sampled_count - 1) / (total_count - 1))
+    targets['mod_ratio'] = batch['selected_prob'] + estimated_sampled_action_b_prob
+    selected_t_policies = naive_selected_t_policies * targets['mod_ratio']
+
+    log_selected_t_policies = torch.log(torch.clamp(selected_t_policies, 1e-16, 1)) * emasks
 
     # thresholds of importance sampling
     log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
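For orientation (numbers made up): mod_ratio = selected_prob + (1 - selected_prob) * (sampled_count - 1) / (total_count - 1), i.e. the behavior probability of the taken action plus the leftover probability mass scaled by the fraction of the remaining 8 actions that were sampled. With selected_prob = 0.3 and sampled_count = 5, that is 0.3 + 0.7 * 4/8 = 0.65; with sampling_rate = 1 as set in make_batch, sampled_count = 9 and mod_ratio is always exactly 1, so selected_t_policies reduces to the plain softmax over the listed actions.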
@@ -248,7 +273,6 @@ def compute_loss(batch, model, hidden, args):
         outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks)
 
     # compute targets and advantage
-    targets = {}
     advantages = {}
 
     value_args = outputs_nograd.get('value', None), batch['outcome'], None, args['lambda'], 1, clipped_rhos, cs, value_target_masks
@@ -324,7 +348,7 @@ def __init__(self, args, model):
         self.args = args
         self.gpu = torch.cuda.device_count()
         self.model = model
-        self.default_lr = 3e-8
+        self.default_lr = 3e-6
         self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps']
         self.params = list(self.model.parameters())
         lr = self.default_lr * self.data_cnt_ema
