diff --git a/handyrl/envs/tictactoe.py b/handyrl/envs/tictactoe.py
index 2c27809c..b0296dd3 100755
--- a/handyrl/envs/tictactoe.py
+++ b/handyrl/envs/tictactoe.py
@@ -59,11 +59,13 @@ def __init__(self):
         self.head_p = Head((filters, 3, 3), 2, 9)
         self.head_v = Head((filters, 3, 3), 1, 1)
 
-    def forward(self, x, hidden=None):
+    def forward(self, x, hidden=None, actions=None):
         h = F.relu(self.conv(x))
         for block in self.blocks:
             h = F.relu(block(h))
         h_p = self.head_p(h)
+        if actions is not None:
+            h_p = h_p.gather(-1, actions)
         h_v = self.head_v(h)
         return {'policy': h_p, 'value': torch.tanh(h_v)}
 
diff --git a/handyrl/train.py b/handyrl/train.py
index 5f8c43ae..a8fdbb67 100755
--- a/handyrl/train.py
+++ b/handyrl/train.py
@@ -88,6 +88,15 @@ def replace_none(a, b):
 
         progress = np.arange(ep['start'], ep['end'], dtype=np.float32)[..., np.newaxis] / ep['total']
 
+        sampling_rate = 1
+        sampled_action_count = int((9 - 1) * sampling_rate)
+        all_actions = np.arange(0, 9 - 1, dtype=np.int64)
+        sampled_actions = np.array([np.random.choice(all_actions, size=(sampled_action_count), replace=False) for _ in range(np.prod(act.shape[:-1]))]).reshape(*act.shape[:-1], -1)
+        sampled_actions += (sampled_actions >= act).astype(np.int64)
+        cat_actions = np.concatenate([act, sampled_actions], -1)
+
+        amask = np.take_along_axis(amask, cat_actions, -1)
+
         # pad each array if step length is short
         batch_steps = args['burn_in_steps'] + args['forward_steps']
         if len(tmask) < batch_steps:
@@ -104,12 +113,13 @@ def replace_none(a, b):
             omask = np.pad(omask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
             amask = np.pad(amask, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=1e32)
             progress = np.pad(progress, [(pad_len_b, pad_len_a), (0, 0)], 'constant', constant_values=1)
+            cat_actions = np.pad(cat_actions, [(pad_len_b, pad_len_a), (0, 0), (0, 0)], 'constant', constant_values=0)
 
         obss.append(obs)
-        datum.append((prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress))
+        datum.append((prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress, cat_actions))
 
     obs = to_torch(bimap_r(obs_zeros, rotate(obss), lambda _, o: np.array(o)))
-    prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress = [to_torch(np.array(val)) for val in zip(*datum)]
+    prob, v, act, oc, rew, ret, emask, tmask, omask, amask, progress, cat_actions = [to_torch(np.array(val)) for val in zip(*datum)]
 
     return {
         'observation': obs,
@@ -121,6 +131,7 @@ def replace_none(a, b):
         'turn_mask': tmask, 'observation_mask': omask,
         'action_mask': amask,
         'progress': progress,
+        'sampled_actions': cat_actions,
     }
 
 
@@ -142,7 +153,8 @@ def forward_prediction(model, hidden, batch, args):
     if hidden is None:
         # feed-forward neural network
         obs = map_r(observations, lambda o: o.flatten(0, 2))  # (..., B * T * P or 1, ...)
-        outputs = model(obs, None)
+        sampled_actions = batch['sampled_actions'].flatten(0, 2)
+        outputs = model(obs, None, sampled_actions)
         outputs = map_r(outputs, lambda o: o.unflatten(0, batch_shape))  # (..., B, T, P or 1, ...)
     else:
         # sequential computation with RNN
@@ -199,7 +211,10 @@ def compose_losses(outputs, log_selected_policies, total_advantages, targets, ba
     losses = {}
     dcnt = tmasks.sum().item()
 
-    losses['p'] = (-log_selected_policies * total_advantages).mul(tmasks).sum()
+    p_selected_loss = -outputs['policy'][:, :, :, :1].mul(total_advantages).mul(tmasks).sum()
+    p_prob_loss = outputs['policy'].mul(F.softmax(outputs['policy'].detach(), -1) * targets['mod_ratio']).mul(total_advantages).mul(tmasks).sum()
+    losses['p'] = p_selected_loss + p_prob_loss
+
     if 'value' in outputs:
         losses['v'] = ((outputs['value'] - targets['value']) ** 2).mul(omasks).sum() / 2
     if 'return' in outputs:
@@ -229,7 +244,17 @@ def compute_loss(batch, model, hidden, args):
     clip_rho_threshold, clip_c_threshold = 1.0, 1.0
 
     log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
-    log_selected_t_policies = F.log_softmax(outputs['policy'], dim=-1).gather(-1, actions) * emasks
+
+    targets = {}
+
+    sampled_count = outputs['policy'].size(-1)
+    total_count = 9
+    naive_selected_t_policies = F.softmax(outputs['policy'], dim=-1)
+    estimated_sampled_action_b_prob = (1 - batch['selected_prob']) * ((sampled_count - 1) / (total_count - 1))
+    targets['mod_ratio'] = batch['selected_prob'] + estimated_sampled_action_b_prob
+    selected_t_policies = naive_selected_t_policies * targets['mod_ratio']
+
+    log_selected_t_policies = torch.log(torch.clamp(selected_t_policies, 1e-16, 1)) * emasks
 
     # thresholds of importance sampling
     log_rhos = log_selected_t_policies.detach() - log_selected_b_policies
@@ -248,7 +273,6 @@ def compute_loss(batch, model, hidden, args):
         outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks)
 
     # compute targets and advantage
-    targets = {}
     advantages = {}
 
     value_args = outputs_nograd.get('value', None), batch['outcome'], None, args['lambda'], 1, clipped_rhos, cs, value_target_masks
@@ -324,7 +348,7 @@ def __init__(self, args, model):
         self.args = args
         self.gpu = torch.cuda.device_count()
         self.model = model
-        self.default_lr = 3e-8
+        self.default_lr = 3e-6
         self.data_cnt_ema = self.args['batch_size'] * self.args['forward_steps']
         self.params = list(self.model.parameters())
         lr = self.default_lr * self.data_cnt_ema
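
Note on the index arithmetic added in make_batch: for every step the patch draws sampled_action_count indices from the 8 non-selected cells and shifts any index greater than or equal to the taken action up by one, so the taken action can never be drawn twice; the taken action is then prepended, which is why slot 0 of the gathered policy logits always corresponds to the action that was actually played. The standalone sketch below is not part of the patch (sample_other_actions, rng and the example shapes are made-up names) and only reproduces that trick on a flat batch of actions.

# Illustrative sketch, not part of the patch: mirrors the index shift used in
# make_batch so that the taken action is excluded from the "other" samples.
import numpy as np

def sample_other_actions(act, total_count=9, sampling_rate=1.0, rng=np.random):
    # act: int64 array of shape (N, 1) holding the action taken at each step
    n = act.shape[0]
    sampled_action_count = int((total_count - 1) * sampling_rate)
    other = np.arange(total_count - 1, dtype=np.int64)   # indices 0 .. total_count-2
    sampled = np.stack([rng.choice(other, size=sampled_action_count, replace=False)
                        for _ in range(n)])
    # Shifting indices >= the taken action up by one maps 0..7 onto the 8 cells
    # that were not played, so the taken action is excluded by construction.
    sampled += (sampled >= act).astype(np.int64)
    # Slot 0 is always the taken action, matching outputs['policy'][..., :1] above.
    return np.concatenate([act, sampled], axis=-1)

acts = np.array([[4], [0]], dtype=np.int64)
print(sample_other_actions(acts))   # first column is always the taken action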
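
A second note on the mod_ratio terms in compute_loss: since the policy head now returns logits only for the gathered slots, its softmax sums to 1 over a subset of the 9 actions, and the patch rescales it by an estimate of the behaviour-policy mass covered by that subset (the recorded probability of the taken action plus an even share of the remaining mass for the sampled distractors) before taking the log used in the importance ratio. The sketch below is illustrative only; the tensor shapes and names such as policy_logits are assumptions, not values from the patch.

# Illustrative sketch of the mod_ratio rescaling; names and shapes are assumed.
import torch
import torch.nn.functional as F

total_count = 9                              # tic-tac-toe action space
policy_logits = torch.randn(2, 5)            # logits for 5 gathered slots (slot 0 = taken action)
selected_prob = torch.rand(2, 1)             # behaviour probability recorded for the taken action

sampled_count = policy_logits.size(-1)
# Assume the leftover behaviour mass (1 - p) is spread evenly over the 8 other
# actions, of which sampled_count - 1 were gathered into the batch.
estimated_sampled_mass = (1 - selected_prob) * (sampled_count - 1) / (total_count - 1)
mod_ratio = selected_prob + estimated_sampled_mass

# Softmax over the gathered slots sums to 1 on the subset; scaling by mod_ratio
# approximates the probability the full 9-way policy would assign to each slot.
selected_t_policies = F.softmax(policy_logits, dim=-1) * mod_ratio
log_pi_taken = torch.log(selected_t_policies[:, :1].clamp(1e-16, 1))
print(log_pi_taken)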