Skip to content

Commit

Permalink
Separate q network target update
Browse files (browse the repository at this point in the history)
  • Loading branch information
araffin committed Sep 16, 2022
1 parent 9704f1d commit f0cc8ff
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion cleanrl/td3_droq_continuous_action_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,20 @@ def actor_loss(params):
)
)

# qf1_state = qf1_state.replace(
# target_params=optax.incremental_update(
# qf1_state.params, qf1_state.target_params, args.tau
# )
# )
# qf2_state = qf2_state.replace(
# target_params=optax.incremental_update(
# qf2_state.params, qf2_state.target_params, args.tau
# )
# )
return actor_state, (qf1_state, qf2_state), actor_loss_value, key

@jax.jit
def update_q_target_networks(qf1_state, qf2_state):
qf1_state = qf1_state.replace(
target_params=optax.incremental_update(
qf1_state.params, qf1_state.target_params, args.tau
Expand All @@ -353,7 +367,7 @@ def actor_loss(params):
qf2_state.params, qf2_state.target_params, args.tau
)
)
return actor_state, (qf1_state, qf2_state), actor_loss_value, key
return qf1_state, qf2_state

start_time = time.time()
n_updates = 0
Expand Down Expand Up @@ -433,6 +447,9 @@ def actor_loss(params):
key,
)

# TODO: check if we need to update actor target too
qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state)

if n_updates % args.policy_frequency == 0:
(
actor_state,
Expand Down

0 comments on commit f0cc8ff

Please sign in to comment.