JAX TD3 prototype (#225)
* initial implementation, needs testing

* use clipped noise, combine qf_nets

* fully unify qf applies

* add correct clip limits

* run pre-commit

* Fix bugs, now it runs

* running tests

* Minor edit

* quick fix

* tests

* remove comment and move splitting to update critic

* one fix and some debugging

* remove debugs

* updating second critic, good result

* precommit

* update docs

* typo

* update docs

* update

* update test cases

* add docs

Co-authored-by: Costa Huang <[email protected]>
joaogui1 and vwxyzjn authored Jul 31, 2022
1 parent 7ce655d commit 5bfdd45
Showing 21 changed files with 424 additions and 11 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -130,6 +130,7 @@ You may also use a prebuilt development environment hosted in Gitpod:
|[Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) | [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy) |
| | [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_action_jaxpy)
|[Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/pdf/1802.09477.pdf) | [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy) |
| | [`td3_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action_jax.py), [docs](/rl-algorithms/td3/#td3_continuous_action_jaxpy) |
|[Phasic Policy Gradient (PPG)](https://arxiv.org/abs/2009.04416) | [`ppg_procgen.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppg_procgen.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppg/#ppg_procgenpy) |

## Open RL Benchmark
11 changes: 10 additions & 1 deletion benchmark/td3.sh
@@ -4,4 +4,13 @@ OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
--command "poetry run python cleanrl/td3_continuous_action.py --track --capture-video" \
--num-seeds 3 \
--workers 3
--workers 3

poetry install -E "mujoco jax"
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a poetry run python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 \
--command "poetry run python cleanrl/td3_continuous_action_jax.py --track --capture-video" \
--num-seeds 3 \
--workers 1
324 changes: 324 additions & 0 deletions cleanrl/td3_continuous_action_jax.py
@@ -0,0 +1,324 @@
import argparse
import os
import random
import time
from distutils.util import strtobool
from typing import Sequence

import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v2",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=1000000,
help="total timesteps of the experiments")
parser.add_argument("--learning-rate", type=float, default=3e-4,
help="the learning rate of the optimizer")
parser.add_argument("--buffer-size", type=int, default=int(1e6),
help="the replay memory buffer size")
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--tau", type=float, default=0.005,
help="target smoothing coefficient (default: 0.005)")
parser.add_argument("--policy-noise", type=float, default=0.2,
help="the scale of policy noise")
parser.add_argument("--batch-size", type=int, default=256,
help="the batch size of sample from the reply memory")
parser.add_argument("--exploration-noise", type=float, default=0.1,
help="the scale of exploration noise")
parser.add_argument("--learning-starts", type=int, default=25e3,
help="timestep to start learning")
parser.add_argument("--policy-frequency", type=int, default=2,
help="the frequency of training policy (delayed)")
parser.add_argument("--noise-clip", type=float, default=0.5,
help="noise clip parameter of the Target Policy Smoothing Regularization")
args = parser.parse_args()
# fmt: on
return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
@nn.compact
def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
x = jnp.concatenate([x, a], -1)
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(1)(x)
return x


class Actor(nn.Module):
action_dim: Sequence[int]
action_scale: Sequence[int]
action_bias: Sequence[int]

@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(self.action_dim)(x)
x = nn.tanh(x)
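# scale and shift the tanh output to the environment's action bounds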
x = x * self.action_scale + self.action_bias
return x


class TrainState(TrainState):
target_params: flax.core.FrozenDict


if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

max_action = float(envs.single_action_space.high[0])
envs.single_observation_space.dtype = np.float32
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device="cpu",
handle_timeout_termination=True,
)

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
actor = Actor(
action_dim=np.prod(envs.single_action_space.shape),
action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
)
actor_state = TrainState.create(
apply_fn=actor.apply,
params=actor.init(actor_key, obs),
target_params=actor.init(actor_key, obs),
tx=optax.adam(learning_rate=args.learning_rate),
)
qf = QNetwork()
qf1_state = TrainState.create(
apply_fn=qf.apply,
params=qf.init(qf1_key, obs, envs.action_space.sample()),
target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
tx=optax.adam(learning_rate=args.learning_rate),
)
qf2_state = TrainState.create(
apply_fn=qf.apply,
params=qf.init(qf2_key, obs, envs.action_space.sample()),
target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
tx=optax.adam(learning_rate=args.learning_rate),
)
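# jit-compile the actor and Q-network forward passes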
actor.apply = jax.jit(actor.apply)
qf.apply = jax.jit(qf.apply)

@jax.jit
def update_critic(
actor_state: TrainState,
qf1_state: TrainState,
qf2_state: TrainState,
observations: np.ndarray,
actions: np.ndarray,
next_observations: np.ndarray,
rewards: np.ndarray,
dones: np.ndarray,
key: jnp.ndarray,
):
# TODO Maybe pre-generate a lot of random keys
# also check https://jax.readthedocs.io/en/latest/jax.random.html
key, noise_key = jax.random.split(key, 2)
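# target policy smoothing: add clipped Gaussian noise to the target policy's action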
clipped_noise = jnp.clip(
(jax.random.normal(noise_key, actions[0].shape) * args.policy_noise),
-args.noise_clip,
args.noise_clip,
)
next_state_actions = jnp.clip(
actor.apply(actor_state.target_params, next_observations) + clipped_noise,
envs.single_action_space.low[0],
envs.single_action_space.high[0],
)
qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
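# clipped double-Q learning: take the minimum of the two target Q-value estimates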
min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
next_q_value = (rewards + (1 - dones) * args.gamma * (min_qf_next_target)).reshape(-1)

def mse_loss(params):
qf_a_values = qf.apply(params, observations, actions).squeeze()
return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()

(qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
(qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
qf1_state = qf1_state.apply_gradients(grads=grads1)
qf2_state = qf2_state.apply_gradients(grads=grads2)

return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key

@jax.jit
def update_actor(
actor_state: TrainState,
qf1_state: TrainState,
qf2_state: TrainState,
observations: np.ndarray,
):
def actor_loss(params):
return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()

actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
actor_state = actor_state.apply_gradients(grads=grads)
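# soft (Polyak) update of the actor and critic target networks, performed together with the delayed actor update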
actor_state = actor_state.replace(
target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
)

qf1_state = qf1_state.replace(
target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
)
qf2_state = qf2_state.replace(
target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
)
return actor_state, (qf1_state, qf2_state), actor_loss_value

start_time = time.time()
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions = actor.apply(actor_state.params, obs)
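# add Gaussian exploration noise to the deterministic action and clip to the valid action range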
actions = np.array(
[
(
jax.device_get(actions)[0]
+ np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0])
).clip(envs.single_action_space.low, envs.single_action_space.high)
]
)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs

# ALGO LOGIC: training.
if global_step > args.learning_starts:
data = rb.sample(args.batch_size)

(qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(
actor_state,
qf1_state,
qf2_state,
data.observations.numpy(),
data.actions.numpy(),
data.next_observations.numpy(),
data.rewards.flatten().numpy(),
data.dones.flatten().numpy(),
key,
)

if global_step % args.policy_frequency == 0:
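# delayed policy updates: refresh the actor and the target networks only every `policy_frequency` steps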
actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(
actor_state,
qf1_state,
qf2_state,
data.observations.numpy(),
)
if global_step % 100 == 0:
writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step)
writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step)
writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()
writer.close()
Binary file modified docs/rl-algorithms/ddpg-jax/HalfCheetah-v2-time.png
Binary file modified docs/rl-algorithms/ddpg-jax/HalfCheetah-v2.png
Binary file modified docs/rl-algorithms/ddpg-jax/Hopper-v2-time.png
Binary file modified docs/rl-algorithms/ddpg-jax/Hopper-v2.png
Binary file modified docs/rl-algorithms/ddpg-jax/Walker2d-v2-time.png
Binary file modified docs/rl-algorithms/ddpg-jax/Walker2d-v2.png
14 changes: 8 additions & 6 deletions docs/rl-algorithms/ddpg.md
@@ -278,16 +278,18 @@ To run benchmark experiments, see :material-github: [benchmark/ddpg.sh](https://

Below are the average episodic returns for [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) (3 random seeds). To ensure the quality of the implementation, we compared the results against (Fujimoto et al., 2018)[^2].

| Environment | [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) | [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py) | [`OurDDPG.py`](https://github.com/sfujim/TD3/blob/master/OurDDPG.py) (Fujimoto et al., 2018, Table 1)[^2] |
| ----------- | ----------- | ----------- | ----------- |
| HalfCheetah | 9910.53 ± 673.49 | 9382.32 ± 1395.52 |8577.29 |
| Walker2d | 1397.60 ± 677.12 | 1598.35 ± 862 | 3098.11 |
| Hopper | 1603.5 ± 727.281 | 1313.43 ± 684.46 | 1860.02 |
| Environment | [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) (RTX 3060 TI) | [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) (VM w/ TPU) | [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py) (RTX 2060) | [`OurDDPG.py`](https://github.com/sfujim/TD3/blob/master/OurDDPG.py) (Fujimoto et al., 2018, Table 1)[^2] |
| ----------- | ----------- | ----------- | ----------- | ----------- |
| HalfCheetah | 9910.53 ± 673.49 | 9790.72 ± 1494.85 | 9382.32 ± 1395.52 |8577.29 |
| Walker2d | 1397.60 ± 677.12 | 1314.83 ± 689.71 | 1598.35 ± 862 | 3098.11 |
| Hopper | 1603.5 ± 727.281 | 1602.20 ± 696.11 | 1313.43 ± 684.46 | 1860.02 |


???+ info

Note that we ran the [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) experiments on an RTX 3060 Ti (~810 SPS) and the [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py) experiments on an RTX 2060 (~241 SPS). Running [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py) on an RTX 3060 Ti brings the SPS from 241 to 325, according to [this report](https://wandb.ai/costa-huang/cleanRL/reports/Torch-DDPG-2060-vs-3060ti---VmlldzoyMzA1NDYy), meaning that on the same hardware [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) would be **roughly 810/325 ≈ 2.5x faster**. However, because of the `--capture-video` overhead that both scripts incur, we suspect [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) would be 3x-4x faster with `--capture-video` disabled.
Note that the experiments were conducted on different hardware, so your mileage might vary. This inconsistency exists because 1) re-running experiments on the same hardware is computationally expensive and 2) requiring the same hardware is neither inclusive nor feasible for contributors who have different hardware.

That said, we roughly expect a 2-4x speed improvement from using [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py) on the same hardware, and the speedup should be even larger with `--capture-video` disabled.
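
As a quick sanity check on the speedup arithmetic, here is a minimal sketch using only the SPS figures quoted in this note (not an additional benchmark):

```python
# SPS figures quoted above, both measured on an RTX 3060 Ti
jax_sps = 810    # ddpg_continuous_action_jax.py
torch_sps = 325  # ddpg_continuous_action.py on the same GPU, per the linked report

print(f"same-hardware speedup: ~{jax_sps / torch_sps:.1f}x")  # -> ~2.5x
```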


Learning curves:

1 comment on commit 5bfdd45

@vercel vercel bot commented on 5bfdd45 Jul 31, 2022

Successfully deployed to the following URLs:

cleanrl – ./

cleanrl.vercel.app
cleanrl-git-master-vwxyzjn.vercel.app
docs.cleanrl.dev
cleanrl-vwxyzjn.vercel.app
