Skip to content

Commit

Permalink
Merge pull request #261 from YuriCat/feature/safe_target_mask
Browse files Browse the repository at this point in the history
feature: apply lambda=1 at timesteps where there is no value output
  • Loading branch information
YuriCat authored Jan 24, 2023
2 parents f20018c + e33a227 commit cf2ef02
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
27 changes: 16 additions & 11 deletions handyrl/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,40 @@ def monte_carlo(values, returns):
return returns, returns - values


def temporal_difference(values, returns, rewards, lambda_, gamma):
    """Compute TD(lambda) targets and advantages by a backward recursion.

    The recursion walks dimension 1 of ``values`` from the last index to the
    first.  The final target is ``returns[:, -1]``; each earlier target is
    ``reward + gamma * ((1 - lambda) * next_value + lambda * next_target)``.

    Args:
        values: Tensor of value estimates, indexed as ``values[:, i]``.
        returns: Tensor whose last slice along dimension 1 bootstraps the
            recursion.
        rewards: Tensor of per-step rewards indexed as ``rewards[:, i]``, or
            ``None`` to use a reward of 0 at every step.
        lambda_: Either a tensor indexed as ``lambda_[:, i + 1]`` (a per-step
            mixing coefficient) or a plain scalar applied at every step.
        gamma: Discount factor.

    Returns:
        Tuple ``(target_values, target_values - values)``.
    """
    per_step_lambda = torch.is_tensor(lambda_)

    target_values = deque([returns[:, -1]])
    for i in range(values.size(1) - 2, -1, -1):
        reward = rewards[:, i] if rewards is not None else 0
        # The coefficient belongs to the step being bootstrapped from (i + 1),
        # so a lambda of 1 there skips that step's value estimate entirely.
        lamb = lambda_[:, i + 1] if per_step_lambda else lambda_
        bootstrap = (1 - lamb) * values[:, i + 1] + lamb * target_values[0]
        target_values.appendleft(reward + gamma * bootstrap)

    target_values = torch.stack(tuple(target_values), dim=1)

    return target_values, target_values - values


def upgo(values, returns, rewards, lambda_, gamma):
    """Compute UPGO-style targets and advantages by a backward recursion.

    Same recursion as a lambda-return, except that the bootstrapped quantity
    at each step is the element-wise maximum of the next value estimate and
    the lambda-mixture of that estimate with the next target.

    Args:
        values: Tensor of value estimates, indexed as ``values[:, i]``.
        returns: Tensor whose last slice along dimension 1 bootstraps the
            recursion.
        rewards: Tensor of per-step rewards indexed as ``rewards[:, i]``, or
            ``None`` to use a reward of 0 at every step.
        lambda_: Either a tensor indexed as ``lambda_[:, i + 1]`` (a per-step
            mixing coefficient) or a plain scalar applied at every step.
        gamma: Discount factor.

    Returns:
        Tuple ``(target_values, target_values - values)``.
    """
    per_step_lambda = torch.is_tensor(lambda_)

    target_values = deque([returns[:, -1]])
    for i in range(values.size(1) - 2, -1, -1):
        value = values[:, i + 1]
        reward = rewards[:, i] if rewards is not None else 0
        # The coefficient is taken at the step being bootstrapped from (i + 1).
        lamb = lambda_[:, i + 1] if per_step_lambda else lambda_
        mixed = (1 - lamb) * value + lamb * target_values[0]
        target_values.appendleft(reward + gamma * torch.max(value, mixed))

    target_values = torch.stack(tuple(target_values), dim=1)

    return target_values, target_values - values


def vtrace(values, returns, rewards, lmb, gamma, rhos, cs):
def vtrace(values, returns, rewards, lambda_, gamma, rhos, cs):
rewards = rewards if rewards is not None else 0
values_t_plus_1 = torch.cat([values[:, 1:], returns[:, -1:]], dim=1)
deltas = rhos * (rewards + gamma * values_t_plus_1 - values)

# compute Vtrace value target recursively
vs_minus_v_xs = deque([deltas[:, -1]])
for i in range(values.size(1) - 2, -1, -1):
vs_minus_v_xs.appendleft(deltas[:, i] + gamma * lmb * cs[:, i] * vs_minus_v_xs[0])
vs_minus_v_xs.appendleft(deltas[:, i] + gamma * lambda_[:, i + 1] * cs[:, i] * vs_minus_v_xs[0])

vs_minus_v_xs = torch.stack(tuple(vs_minus_v_xs), dim=1)
vs = vs_minus_v_xs + values
Expand All @@ -58,18 +60,21 @@ def vtrace(values, returns, rewards, lmb, gamma, rhos, cs):
return vs, advantages


def compute_target(algorithm, values, returns, rewards, lmb, gamma, rhos, cs, masks):
    """Dispatch to the target/advantage estimator named by ``algorithm``.

    Args:
        algorithm: One of ``'MC'``, ``'TD'``, ``'UPGO'`` or ``'VTRACE'``.
        values: Baseline value estimates, or ``None`` when there is no
            baseline.
        returns: Observed returns.
        rewards: Per-step rewards, or ``None``.
        lmb: Scalar lambda used where ``masks`` is 1.
        gamma: Discount factor.
        rhos: Forwarded to ``vtrace`` only.
        cs: Forwarded to ``vtrace`` only.
        masks: Mask that is 1 where a value output exists and 0 where it
            does not; used to build the per-step lambda.

    Returns:
        Tuple ``(targets, advantages)`` from the selected estimator.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    if values is None:
        # In the absence of a baseline, Monte Carlo returns are used.
        return returns, returns

    if algorithm == 'MC':
        return monte_carlo(values, returns)

    # Fail loudly instead of printing and returning None, which callers that
    # unpack the result would otherwise hit as an unrelated TypeError.
    if algorithm not in ('TD', 'UPGO', 'VTRACE'):
        raise ValueError('No algorithm named %s' % algorithm)

    # Per-step lambda: lmb where masks == 1, and exactly 1 where masks == 0,
    # so steps without a value output are not bootstrapped from.
    lambda_ = lmb + (1 - lmb) * (1 - masks)

    if algorithm == 'TD':
        return temporal_difference(values, returns, rewards, lambda_, gamma)
    if algorithm == 'UPGO':
        return upgo(values, returns, rewards, lambda_, gamma)
    return vtrace(values, returns, rewards, lambda_, gamma, rhos, cs)
7 changes: 5 additions & 2 deletions handyrl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def compute_loss(batch, model, hidden, args):
actions = batch['action']
emasks = batch['episode_mask']
omasks = batch['observation_mask']
value_target_masks, return_target_masks = omasks, omasks

clip_rho_threshold, clip_c_threshold = 1.0, 1.0

log_selected_b_policies = torch.log(torch.clamp(batch['selected_prob'], 1e-16, 1)) * emasks
Expand All @@ -243,14 +245,15 @@ def compute_loss(batch, model, hidden, args):
values_nograd_opponent = -torch.flip(values_nograd, dims=[2])
omasks_opponent = torch.flip(omasks, dims=[2])
values_nograd = (values_nograd * omasks + values_nograd_opponent * omasks_opponent) / (omasks + omasks_opponent + 1e-8)
value_target_masks = torch.clamp(omasks + omasks_opponent, 0, 1)
outputs_nograd['value'] = values_nograd * emasks + batch['outcome'] * (1 - emasks)

# compute targets and advantage
targets = {}
advantages = {}

value_args = outputs_nograd.get('value', None), batch['outcome'], None, args['lambda'], 1, clipped_rhos, cs
return_args = outputs_nograd.get('return', None), batch['return'], batch['reward'], args['lambda'], args['gamma'], clipped_rhos, cs
value_args = outputs_nograd.get('value', None), batch['outcome'], None, args['lambda'], 1, clipped_rhos, cs, value_target_masks
return_args = outputs_nograd.get('return', None), batch['return'], batch['reward'], args['lambda'], args['gamma'], clipped_rhos, cs, return_target_masks

targets['value'], advantages['value'] = compute_target(args['value_target'], *value_args)
targets['return'], advantages['return'] = compute_target(args['value_target'], *return_args)
Expand Down

0 comments on commit cf2ef02

Please sign in to comment.