Skip to content

Commit

Permalink
Add SAC-discrete implementation and docs (#270)
Browse files Browse the repository at this point in the history
* add draft of SAC discrete implementation

* run pre-commit

* Use log softmax instead of author's log-pi code

* Revert to cleanrl SAC delay implementation (it's more stable)

* Remove docstrings and duplicate code

* Use correct clipreward wrapper

* fix bug in log softmax calculation

* adhere to cleanrl log_prob naming

* fix bug in entropy target calculation

* change layer initialization to match existing cleanrl codebase

* working minimal diff version

* implement original learning update frequency

* parameterize the entropy scale for autotuning

* add benchmarking script

* rename target entropy factor and set new default value

* add docs draft

* fix SAC-discrete links to work pre merge

* add preliminary result table for SAC-discrete

* clean up todos and add header

* minimize diff between sac_atari and sac_continuous

* add sac-discrete end2end test

* SAC-discrete docs rework

* Update SAC-discrete @100k results

* Fix doc links and unify naming in code

* update docs

* fix target update frequency (see PR #323)

* clarify comment regarding CNN encoder sharing

* fix benchmark installation

* fix eps in minimal diff version and improve code readability

* add docs for eps and finalize code

* use no_grad for actor Q-vals and re-use action-probs & log-probs in alpha loss

* update docs for new code and settings

* fix links to point to main branch

* update sac-discrete training plots

* new sac-d training plots

* update results table and fix link

* fix pong chart title

* add Jimmy Ba name as exception to code spell check

* change target_entropy_scale default value to same value as experiments

* remove blank line at end of pre-commit

Co-authored-by: Costa Huang <[email protected]>
  • Loading branch information
timoklein and vwxyzjn authored Jan 13, 2023
1 parent 30381ee commit c3fc57d
Show file tree
Hide file tree
Showing 11 changed files with 558 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ repos:
hooks:
- id: codespell
args:
- --ignore-words-list=nd,reacher,thist,ths,magent
- --ignore-words-list=nd,reacher,thist,ths,magent,Ba
- --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb
- repo: https://github.com/python-poetry/poetry
rev: 1.2.1
Expand Down Expand Up @@ -79,4 +79,4 @@ repos:
- id: poetry-export
name: poetry-export requirements-cloud.txt
args: ["--without-hashes", "-o", "requirements/requirements-cloud.txt", "--with", "cloud"]
stages: [manual]
stages: [manual]
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ poetry install --with atari
python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/ppo_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/sac_atari.py --env-id BreakoutNoFrameskip-v4
# NEW: 3-4x side-effects free speed up with envpool's atari (only available to linux)
poetry install --with envpool
Expand Down Expand Up @@ -129,14 +130,15 @@ You may also use a prebuilt development environment hosted in Gitpod:
| | [`ppo_pettingzoo_ma_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy)
| | [`ppo_continuous_action_isaacgym.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_action_isaacgympy)
|[Deep Q-Learning (DQN)](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf) | [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy) |
| | [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy) |
| | [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy) |
| | [`dqn_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy) |
| | [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_atari_jaxpy) |
|[Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) | [`c51.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51py) |
| | [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_ataripy) |
| | [`c51_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_jaxpy) |
| | [`c51_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_atari_jaxpy) |
|[Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) | [`sac_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy) |
| | [`sac_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/sac/#sac_atarinpy) |
|[Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) | [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy) |
| | [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_action_jaxpy)
|[Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/pdf/1802.09477.pdf) | [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy) |
Expand Down
6 changes: 6 additions & 0 deletions benchmark/sac_atari.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
poetry install --with atari
OMP_NUM_THREADS=1 python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BreakoutNoFrameskip-v4 BeamRiderNoFrameskip-v4 \
--command "poetry run python cleanrl/sac_atari.py --cuda True --track" \
--num-seeds 3 \
--workers 2
342 changes: 342 additions & 0 deletions cleanrl/sac_atari.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,342 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="weather to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BeamRiderNoFrameskip-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=5000000,
help="total timesteps of the experiments")
parser.add_argument("--buffer-size", type=int, default=int(1e6),
help="the replay memory buffer size") # smaller than in original paper but evaluation is done only for 100k steps anyway
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--tau", type=float, default=1.0,
help="target smoothing coefficient (default: 1)") # Default is 1 to perform replacement update
parser.add_argument("--batch-size", type=int, default=64,
help="the batch size of sample from the reply memory")
parser.add_argument("--learning-starts", type=int, default=2e4,
help="timestep to start learning")
parser.add_argument("--policy-lr", type=float, default=3e-4,
help="the learning rate of the policy network optimizer")
parser.add_argument("--q-lr", type=float, default=3e-4,
help="the learning rate of the Q network network optimizer")
parser.add_argument("--update-frequency", type=int, default=4,
help="the frequency of training updates")
parser.add_argument("--target-network-frequency", type=int, default=8000,
help="the frequency of updates for the target networks")
parser.add_argument("--alpha", type=float, default=0.2,
help="Entropy regularization coefficient.")
parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
help="automatic tuning of the entropy coefficient")
parser.add_argument("--target-entropy-scale", type=float, default=0.89,
help="coefficient for scaling the autotune entropy target")
args = parser.parse_args()
# fmt: on
return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = ClipRewardEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk


def layer_init(layer, bias_const=0.0):
nn.init.kaiming_normal_(layer.weight)
torch.nn.init.constant_(layer.bias, bias_const)
return layer


# ALGO LOGIC: initialize agent here:
# NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients
# See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info
# TL;DR The actor's gradients mess up the representation when using a joint encoder
class SoftQNetwork(nn.Module):
def __init__(self, envs):
super().__init__()
obs_shape = envs.single_observation_space.shape
self.conv = nn.Sequential(
layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
nn.ReLU(),
layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
nn.ReLU(),
layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
nn.Flatten(),
)

with torch.inference_mode():
output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

self.fc1 = layer_init(nn.Linear(output_dim, 512))
self.fc_q = layer_init(nn.Linear(512, envs.single_action_space.n))

def forward(self, x):
x = F.relu(self.conv(x / 255.0))
x = F.relu(self.fc1(x))
q_vals = self.fc_q(x)
return q_vals


class Actor(nn.Module):
def __init__(self, envs):
super().__init__()
obs_shape = envs.single_observation_space.shape
self.conv = nn.Sequential(
layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
nn.ReLU(),
layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
nn.ReLU(),
layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
nn.Flatten(),
)

with torch.inference_mode():
output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

self.fc1 = layer_init(nn.Linear(output_dim, 512))
self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n))

def forward(self, x):
x = F.relu(self.conv(x))
x = F.relu(self.fc1(x))
logits = self.fc_logits(x)

return logits

def get_action(self, x):
logits = self(x / 255.0)
policy_dist = Categorical(logits=logits)
action = policy_dist.sample()
# Action probabilities for calculating the adapted soft-Q loss
action_probs = policy_dist.probs
log_prob = F.log_softmax(logits, dim=1)
return action, log_prob, action_probs


if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

actor = Actor(envs).to(device)
qf1 = SoftQNetwork(envs).to(device)
qf2 = SoftQNetwork(envs).to(device)
qf1_target = SoftQNetwork(envs).to(device)
qf2_target = SoftQNetwork(envs).to(device)
qf1_target.load_state_dict(qf1.state_dict())
qf2_target.load_state_dict(qf2.state_dict())
# TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)

# Automatic entropy tuning
if args.autotune:
target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha = log_alpha.exp().item()
a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
else:
alpha = args.alpha

rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
actions = actions.detach().cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs

# ALGO LOGIC: training.
if global_step > args.learning_starts:
if global_step % args.update_frequency == 0:
data = rb.sample(args.batch_size)
# CRITIC training
with torch.no_grad():
_, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
qf1_next_target = qf1_target(data.next_observations)
qf2_next_target = qf2_target(data.next_observations)
# we can use the action probabilities instead of MC sampling to estimate the expectation
min_qf_next_target = next_state_action_probs * (
torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
)
# adapt Q-target for discrete Q-function
min_qf_next_target = min_qf_next_target.sum(dim=1)
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target)

# use Q-values only for the taken actions
qf1_values = qf1(data.observations)
qf2_values = qf2(data.observations)
qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1)
qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1)
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = qf1_loss + qf2_loss

q_optimizer.zero_grad()
qf_loss.backward()
q_optimizer.step()

# ACTOR training
_, log_pi, action_probs = actor.get_action(data.observations)
with torch.no_grad():
qf1_values = qf1(data.observations)
qf2_values = qf2(data.observations)
min_qf_values = torch.min(qf1_values, qf2_values)
# no need for reparameterization, the expectation can be calculated for discrete actions
actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()

actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()

if args.autotune:
# re-use action probabilities for temperature loss
alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()

a_optimizer.zero_grad()
alpha_loss.backward()
a_optimizer.step()
alpha = log_alpha.exp().item()

# update the target networks
if global_step % args.target_network_frequency == 0:
for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

if global_step % 100 == 0:
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
writer.add_scalar("losses/alpha", alpha, global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
if args.autotune:
writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

envs.close()
writer.close()
Loading

1 comment on commit c3fc57d

@vercel
Copy link

@vercel vercel bot commented on c3fc57d Jan 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

cleanrl – ./

cleanrl.vercel.app
docs.cleanrl.dev
cleanrl-vwxyzjn.vercel.app
cleanrl-git-master-vwxyzjn.vercel.app

Please sign in to comment.