Add SAC-discrete implementation and docs (#270)
* add draft of SAC discrete implementation
* run pre-commit
* Use log softmax instead of author's log-pi code
* Revert to cleanrl SAC delay implementation (it's more stable)
* Remove docstrings and duplicate code
* Use correct clipreward wrapper
* fix bug in log softmax calculation
* adhere to cleanrl log_prob naming
* fix bug in entropy target calculation
* change layer initialization to match existing cleanrl codebase
* working minimal diff version
* implement original learning update frequency
* parameterize the entropy scale for autotuning
* add benchmarking script
* rename target entropy factor and set new default value
* add docs draft
* fix SAC-discrete links to work pre merge
* add preliminary result table for SAC-discrete
* clean up todos and add header
* minimize diff between sac_atari and sac_continuous
* add sac-discrete end2end test
* SAC-discrete docs rework
* Update SAC-discrete @100k results
* Fix doc links and unify naming in code
* update docs
* fix target update frequency (see PR #323)
* clarify comment regarding CNN encoder sharing
* fix benchmark installation
* fix eps in minimal diff version and improve code readability
* add docs for eps and finalize code
* use no_grad for actor Q-vals and re-use action-probs & log-probs in alpha loss
* update docs for new code and settings
* fix links to point to main branch
* update sac-discrete training plots
* new sac-d training plots
* update results table and fix link
* fix pong chart title
* add Jimmy Ba name as exception to code spell check
* change target_entropy_scale default value to same value as experiments
* remove blank line at end of pre-commit

Co-authored-by: Costa Huang <[email protected]>
Showing 11 changed files with 558 additions and 20 deletions.
@@ -0,0 +1,6 @@
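# Benchmark SAC-discrete on three Atari environments (3 seeds each, 2 concurrent workers).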
poetry install --with atari
OMP_NUM_THREADS=1 python -m cleanrl_utils.benchmark \
    --env-ids PongNoFrameskip-v4 BreakoutNoFrameskip-v4 BeamRiderNoFrameskip-v4 \
    --command "poetry run python cleanrl/sac_atari.py --cuda True --track" \
    --num-seeds 3 \
    --workers 2
@@ -0,0 +1,342 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="BeamRiderNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=5000000,
        help="total timesteps of the experiments")
    parser.add_argument("--buffer-size", type=int, default=int(1e6),
        help="the replay memory buffer size")  # smaller than in original paper but evaluation is done only for 100k steps anyway
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--tau", type=float, default=1.0,
        help="target smoothing coefficient (default: 1)")  # Default is 1 to perform replacement update
    parser.add_argument("--batch-size", type=int, default=64,
        help="the batch size of samples from the replay memory")
    parser.add_argument("--learning-starts", type=int, default=2e4,
        help="timestep to start learning")
    parser.add_argument("--policy-lr", type=float, default=3e-4,
        help="the learning rate of the policy network optimizer")
    parser.add_argument("--q-lr", type=float, default=3e-4,
        help="the learning rate of the Q network optimizer")
    parser.add_argument("--update-frequency", type=int, default=4,
        help="the frequency of training updates")
    parser.add_argument("--target-network-frequency", type=int, default=8000,
        help="the frequency of updates for the target networks")
    parser.add_argument("--alpha", type=float, default=0.2,
        help="Entropy regularization coefficient.")
    parser.add_argument("--autotune", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="automatic tuning of the entropy coefficient")
    parser.add_argument("--target-entropy-scale", type=float, default=0.89,
        help="coefficient for scaling the autotune entropy target")
    args = parser.parse_args()
    # fmt: on
    return args

def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk

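# NOTE: Kaiming-normal weight initialization with a constant bias (cf. "change layer
# initialization to match existing cleanrl codebase" in the commit message above).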
def layer_init(layer, bias_const=0.0):
    nn.init.kaiming_normal_(layer.weight)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

# ALGO LOGIC: initialize agent here:
# NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients
# See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info
# TL;DR The actor's gradients mess up the representation when using a joint encoder
class SoftQNetwork(nn.Module):
    def __init__(self, envs):
        super().__init__()
        obs_shape = envs.single_observation_space.shape
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )

        with torch.inference_mode():
            output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

        self.fc1 = layer_init(nn.Linear(output_dim, 512))
        self.fc_q = layer_init(nn.Linear(512, envs.single_action_space.n))

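    # Pixel observations arrive as uint8 frames in [0, 255]; scale to [0, 1] before the CNN.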
    def forward(self, x):
        x = F.relu(self.conv(x / 255.0))
        x = F.relu(self.fc1(x))
        q_vals = self.fc_q(x)
        return q_vals

class Actor(nn.Module):
    def __init__(self, envs):
        super().__init__()
        obs_shape = envs.single_observation_space.shape
        self.conv = nn.Sequential(
            layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
            nn.Flatten(),
        )

        with torch.inference_mode():
            output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

        self.fc1 = layer_init(nn.Linear(output_dim, 512))
        self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n))

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = F.relu(self.fc1(x))
        logits = self.fc_logits(x)

        return logits

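    # NOTE: unlike continuous SAC, get_action returns the probabilities and log-probabilities
    # of every discrete action, so the losses below can take exact expectations over actions
    # instead of relying on a single sampled action.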
    def get_action(self, x):
        logits = self(x / 255.0)
        policy_dist = Categorical(logits=logits)
        action = policy_dist.sample()
        # Action probabilities for calculating the adapted soft-Q loss
        action_probs = policy_dist.probs
        log_prob = F.log_softmax(logits, dim=1)
        return action, log_prob, action_probs

if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    # TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)

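    # For a discrete action space the maximum policy entropy is log(n_actions) (a uniform
    # policy); the autotuned target below is the fraction --target-entropy-scale of that
    # maximum, since -scale * log(1 / n_actions) = scale * log(n_actions).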
    # Automatic entropy tuning
    if args.autotune:
        target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
    else:
        alpha = args.alpha

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device,
        handle_timeout_termination=True,
    )
    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

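        # When a sub-environment finishes, the vector env auto-resets and `next_obs` already
        # holds the first observation of the new episode; store the true final observation
        # from `infos` instead so that bootstrapping in the Q-target remains correct.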
        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.update_frequency == 0:
                data = rb.sample(args.batch_size)
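                # CRITIC target: for discrete actions the soft state value
                #     V(s') = sum_a pi(a|s') * (min_i Q_i_target(s', a) - alpha * log pi(a|s'))
                # can be computed exactly by weighting with the action probabilities,
                # so no next-action sampling is required.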
                # CRITIC training
                with torch.no_grad():
                    _, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
                    qf1_next_target = qf1_target(data.next_observations)
                    qf2_next_target = qf2_target(data.next_observations)
                    # we can use the action probabilities instead of MC sampling to estimate the expectation
                    min_qf_next_target = next_state_action_probs * (
                        torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                    )
                    # adapt Q-target for discrete Q-function
                    min_qf_next_target = min_qf_next_target.sum(dim=1)
                    next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target)

                # use Q-values only for the taken actions
                qf1_values = qf1(data.observations)
                qf2_values = qf2(data.observations)
                qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1)
                qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1)
                qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
                qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
                qf_loss = qf1_loss + qf2_loss

                q_optimizer.zero_grad()
                qf_loss.backward()
                q_optimizer.step()

                # ACTOR training
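                # The critics are queried under no_grad: the actor loss should only propagate
                # gradients through the policy probabilities, not through the Q-networks.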
                _, log_pi, action_probs = actor.get_action(data.observations)
                with torch.no_grad():
                    qf1_values = qf1(data.observations)
                    qf2_values = qf2(data.observations)
                    min_qf_values = torch.min(qf1_values, qf2_values)
                # no need for reparameterization, the expectation can be calculated for discrete actions
                actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()

                actor_optimizer.zero_grad()
                actor_loss.backward()
                actor_optimizer.step()

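                # TEMPERATURE loss: the per-action probabilities are re-used to weight
                # -log_alpha * (log pi(a|s) + target_entropy), so no sampled action is needed
                # to estimate the entropy term.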
                if args.autotune:
                    # re-use action probabilities for temperature loss
                    alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()

                    a_optimizer.zero_grad()
                    alpha_loss.backward()
                    a_optimizer.step()
                    alpha = log_alpha.exp().item()

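            # With the default --tau 1.0 this is a hard copy of the online critics into the
            # target networks every --target-network-frequency steps (Polyak averaging for tau < 1).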
            # update the target networks
            if global_step % args.target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
                writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
                writer.add_scalar("losses/alpha", alpha, global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                if args.autotune:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

    envs.close()
    writer.close()
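# Quick start (a sketch, assuming the poetry environment from the benchmark file is installed):
#   poetry run python cleanrl/sac_atari.py --env-id PongNoFrameskip-v4 --total-timesteps 100000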
c3fc57d
Successfully deployed to the following URLs:
cleanrl – ./
cleanrl.vercel.app
docs.cleanrl.dev
cleanrl-vwxyzjn.vercel.app
cleanrl-git-master-vwxyzjn.vercel.app