Skip to content

Commit

Permalink
Fix the way the mask is calculated
Browse files Browse the repository at this point in the history
  • Loading branch information
quangr committed Feb 1, 2023
1 parent 7ad95f4 commit c79d66f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def compute_gae(
)
return storage, agent_state

def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, truncated):
def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, mask):
newlogprob, entropy, newvalue = get_action_and_value2(params, x, a)
logratio = newlogprob - logp
ratio = jnp.exp(logratio)
Expand All @@ -468,10 +468,10 @@ def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, truncated):
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
# mask truncated state
pg_loss = (jnp.maximum(pg_loss1, pg_loss2) * (1 - truncated)).sum() / (1 - truncated).sum()
pg_loss = (jnp.maximum(pg_loss1, pg_loss2) * (1 - mask)).sum() / (1 - mask).sum()

# Value loss
v_loss = (((newvalue - mb_returns) * (1 - truncated)) ** 2).sum() / (1 - truncated).sum()
v_loss = (((newvalue - mb_returns) * (1 - mask)) ** 2).sum() / (1 - mask).sum()

entropy_loss = entropy.mean()
loss = pg_loss + v_loss * args.vf_coef
Expand Down Expand Up @@ -513,7 +513,7 @@ def update_minibatch(agent_state, minibatch):
minibatch.logprobs,
minibatch.advantages,
minibatch.returns,
minibatch.truncated,
1 - (1 - minibatch.truncated) * (1 - minibatch.dones),
)
agent_state = agent_state.apply_gradients(grads=grads)
return agent_state, (
Expand Down

0 comments on commit c79d66f

Please sign in to comment.