Commit
* initiali implementation, needs testing
* use clipped noise, combine qf_nets
* fully unify qf applies
* add correct clip limits
* run pre-commit
* Fix bugs, now it runs
* running tests
* Minor edit
* quick fix
* tests
* remove comment and move splitting to update critic
* one fix and some debugging
* remove debugs
* updating second critic, good result
* precommit
* update docs
* typo
* update docs
* update
* update test cases
* add docs

Co-authored-by: Costa Huang <[email protected]>
Showing 21 changed files with 424 additions and 11 deletions.
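The "use clipped noise" and "add correct clip limits" bullets in the commit message refer to TD3's target policy smoothing and clipped double-Q target, which the new file below implements. As a quick orientation, here is a minimal, self-contained sketch of that target computation; the array shapes, bound values, and names such as policy_noise, noise_clip, and td_target are illustrative stand-ins, not code from the diff.

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)
    policy_noise, noise_clip, gamma = 0.2, 0.5, 0.99   # same roles as --policy-noise, --noise-clip, --gamma
    act_low, act_high = -1.0, 1.0                       # action bounds of a hypothetical environment

    # Stand-ins for the target actor's actions and the two target critics' values on a batch of 256 transitions.
    target_actions = jnp.zeros((256, 6))
    qf1_next, qf2_next = jnp.ones((256,)), 2.0 * jnp.ones((256,))
    rewards, dones = jnp.zeros((256,)), jnp.zeros((256,))

    # Target policy smoothing: clipped Gaussian noise on the target action, clipped back to the action range.
    noise = jnp.clip(jax.random.normal(key, target_actions.shape) * policy_noise, -noise_clip, noise_clip)
    smoothed_actions = jnp.clip(target_actions + noise, act_low, act_high)
    # (in the real code below, the smoothed actions are what get fed to the target critics)

    # Clipped double-Q: the TD target takes the element-wise minimum of the two target critics.
    td_target = rewards + (1.0 - dones) * gamma * jnp.minimum(qf1_next, qf2_next)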
@@ -0,0 +1,324 @@
import argparse
import os
import random
import time
from distutils.util import strtobool
from typing import Sequence

import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="HalfCheetah-v2",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=1000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=3e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=int(1e6),
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--tau", type=float, default=0.005,
        help="target smoothing coefficient (default: 0.005)")
    parser.add_argument("--policy-noise", type=float, default=0.2,
        help="the scale of policy noise")
    parser.add_argument("--batch-size", type=int, default=256,
        help="the batch size of sample from the replay memory")
    parser.add_argument("--exploration-noise", type=float, default=0.1,
        help="the scale of exploration noise")
    parser.add_argument("--learning-starts", type=int, default=25e3,
        help="timestep to start learning")
    parser.add_argument("--policy-frequency", type=int, default=2,
        help="the frequency of training policy (delayed)")
    parser.add_argument("--noise-clip", type=float, default=0.5,
        help="noise clip parameter of the Target Policy Smoothing Regularization")
    args = parser.parse_args()
    # fmt: on
    return args


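# Note: `--policy-noise` and `--noise-clip` drive TD3's target policy smoothing,
# `--policy-frequency` its delayed actor/target updates, and `--tau` the Polyak
# averaging of the target networks (see `update_critic` / `update_actor` below).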
def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


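# The single QNetwork definition below is initialized twice (qf1_state, qf2_state)
# to form TD3's twin critics: one module definition, two separate parameter sets.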
# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        x = jnp.concatenate([x, a], -1)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


class Actor(nn.Module):
    action_dim: Sequence[int]
    action_scale: Sequence[int]
    action_bias: Sequence[int]

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        x = nn.tanh(x)
        x = x * self.action_scale + self.action_bias
        return x


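# The TrainState subclass below adds `target_params`, so each network carries its own
# target copy; the targets are moved toward the online parameters by Polyak averaging
# (optax.incremental_update) in `update_actor`.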
class TrainState(TrainState):
    target_params: flax.core.FrozenDict


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    max_action = float(envs.single_action_space.high[0])
    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device="cpu",
        handle_timeout_termination=True,
    )

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    actor = Actor(
        action_dim=np.prod(envs.single_action_space.shape),
        action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
        action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
    )
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        target_params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf = QNetwork()
    qf1_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf1_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf2_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf2_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    actor.apply = jax.jit(actor.apply)
    qf.apply = jax.jit(qf.apply)

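    # Both critics are trained with a single shared loss definition: the same `mse_loss`
    # below is differentiated once per critic's parameters against one TD target
    # (the "unify qf applies" change mentioned in the commit message).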
    @jax.jit
    def update_critic(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
        actions: np.ndarray,
        next_observations: np.ndarray,
        rewards: np.ndarray,
        dones: np.ndarray,
        key: jnp.ndarray,
    ):
        # TODO Maybe pre-generate a lot of random keys
        # also check https://jax.readthedocs.io/en/latest/jax.random.html
        key, noise_key = jax.random.split(key, 2)
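        # Target policy smoothing: perturb the target actor's action with clipped
        # Gaussian noise, then clip the result back into the valid action range.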
        clipped_noise = jnp.clip(
            (jax.random.normal(noise_key, actions[0].shape) * args.policy_noise),
            -args.noise_clip,
            args.noise_clip,
        )
        next_state_actions = jnp.clip(
            actor.apply(actor_state.target_params, next_observations) + clipped_noise,
            envs.single_action_space.low[0],
            envs.single_action_space.high[0],
        )
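        # Clipped double-Q learning: the TD target uses the minimum of the two target
        # critics to reduce value overestimation.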
        qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
        qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
        min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
        next_q_value = (rewards + (1 - dones) * args.gamma * (min_qf_next_target)).reshape(-1)

        def mse_loss(params):
            qf_a_values = qf.apply(params, observations, actions).squeeze()
            return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()

        (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
        (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
        qf1_state = qf1_state.apply_gradients(grads=grads1)
        qf2_state = qf2_state.apply_gradients(grads=grads2)

        return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key

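    # Delayed policy update: the actor is trained against qf1 only (deterministic
    # policy gradient), and the actor and both critic targets are Polyak-averaged here,
    # since this function only runs every `policy_frequency` steps.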
    @jax.jit
    def update_actor(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
    ):
        def actor_loss(params):
            return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()

        actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=grads)
        actor_state = actor_state.replace(
            target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
        )

        qf1_state = qf1_state.replace(
            target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
        )
        qf2_state = qf2_state.replace(
            target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
        )
        return actor_state, (qf1_state, qf2_state), actor_loss_value

    start_time = time.time()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions = actor.apply(actor_state.params, obs)
            actions = np.array(
                [
                    (
                        jax.device_get(actions)[0]
                        + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0])
                    ).clip(envs.single_action_space.low, envs.single_action_space.high)
                ]
            )

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)

            (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(
                actor_state,
                qf1_state,
                qf2_state,
                data.observations.numpy(),
                data.actions.numpy(),
                data.next_observations.numpy(),
                data.rewards.flatten().numpy(),
                data.dones.flatten().numpy(),
                key,
            )

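            # Delayed policy updates: the actor and the target networks move only every
            # `policy_frequency` critic updates.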
            if global_step % args.policy_frequency == 0:
                actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(
                    actor_state,
                    qf1_state,
                    qf2_state,
                    data.observations.numpy(),
                )
            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step)
                writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step)
                writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()
5bfdd45
Successfully deployed to the following URLs:
cleanrl – ./
cleanrl.vercel.app
cleanrl-git-master-vwxyzjn.vercel.app
docs.cleanrl.dev
cleanrl-vwxyzjn.vercel.app