diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py index 081efe871..fcfafa6ea 100644 --- a/cleanrl/dqn.py +++ b/cleanrl/dqn.py @@ -214,7 +214,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): if global_step % args.target_network_frequency == 0: for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()): target_network_param.data.copy_( - args.tau * q_network_param.data + (1. - args.tau) * target_network_param.data) + args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data + ) if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" diff --git a/cleanrl/dqn_atari.py b/cleanrl/dqn_atari.py index 7202c970a..e0e5a2b4d 100644 --- a/cleanrl/dqn_atari.py +++ b/cleanrl/dqn_atari.py @@ -236,7 +236,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): if global_step % args.target_network_frequency == 0: for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()): target_network_param.data.copy_( - args.tau * q_network_param.data + (1. - args.tau) * target_network_param.data) + args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data + ) if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" diff --git a/cleanrl/dqn_atari_jax.py b/cleanrl/dqn_atari_jax.py index 1cea0bee2..12a4e16ae 100644 --- a/cleanrl/dqn_atari_jax.py +++ b/cleanrl/dqn_atari_jax.py @@ -264,7 +264,9 @@ def mse_loss(params): # update target network if global_step % args.target_network_frequency == 0: - q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau)) + q_state = q_state.replace( + target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau) + ) if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" diff --git a/cleanrl/dqn_jax.py b/cleanrl/dqn_jax.py index e5802a8ee..82c05499e 100644 --- a/cleanrl/dqn_jax.py +++ b/cleanrl/dqn_jax.py @@ -236,7 +236,9 @@ def mse_loss(params): # update target network if global_step % args.target_network_frequency == 0: - q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau)) + q_state = q_state.replace( + target_params=optax.incremental_update(q_state.params, q_state.target_params, args.tau) + ) if args.save_model: model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"