Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SAC-discrete implementation #270

Merged
merged 42 commits into from
Jan 13, 2023
Merged
Changes from 6 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
18b643b
add draft of SAC discrete implementation
timoklein Aug 29, 2022
c3c98bd
run pre-commit
timoklein Aug 29, 2022
ec31dc4
Use log softmax instead of author's log-pi code
timoklein Aug 31, 2022
deb37e8
Revert to cleanrl SAC delay implementation (it's more stable)
timoklein Aug 31, 2022
a1fdd2b
Remove docstrings and duplicate code
timoklein Aug 31, 2022
977a83a
Use correct clipreward wrapper
timoklein Aug 31, 2022
f2ea3e6
fix bug in log softmax calculation
timoklein Sep 6, 2022
48af04c
adhere to cleanrl log_prob naming
timoklein Sep 6, 2022
b2a09a0
fix bug in entropy target calculation
timoklein Sep 6, 2022
89680c7
change layer initialization to match existing cleanrl codebase
timoklein Sep 6, 2022
b1d7d44
working minimal diff version
timoklein Sep 19, 2022
61e1c74
implement original learning update frequency
timoklein Sep 20, 2022
7cd1e3a
parameterize the entropy scale for autotuning
timoklein Oct 3, 2022
61c46fc
add benchmarking script
timoklein Oct 3, 2022
4915e4c
rename target entropy factor and set new default value
timoklein Oct 6, 2022
6f7251f
add docs draft
timoklein Nov 5, 2022
23b60ff
fix SAC-discrete links to work pre merge
timoklein Nov 10, 2022
10ee9f0
add preliminary result table for SAC-discrete
timoklein Nov 10, 2022
8430fd8
clean up todos and add header
timoklein Nov 10, 2022
a17768c
minimize diff between sac_atari and sac_continuous
timoklein Nov 11, 2022
d6a507c
add sac-discrete end2end test
timoklein Nov 11, 2022
a7ea6f4
SAC-discrete docs rework
timoklein Nov 11, 2022
9f6493c
Update SAC-discrete @100k results
timoklein Nov 12, 2022
59a6d00
Fix doc links and unify naming in code
timoklein Nov 12, 2022
1304b7a
update docs
vwxyzjn Nov 13, 2022
3a3f41b
fix target update frequency (see PR #323)
timoklein Nov 24, 2022
80187ad
clarify comment regarding CNN encoder sharing
timoklein Nov 24, 2022
e9cb494
Merge remote-tracking branch 'upstream/master' into sac-discrete
timoklein Nov 25, 2022
e199e39
fix benchmark installation
timoklein Nov 25, 2022
bb27fa1
fix eps in minimal diff version and improve code readability
timoklein Dec 3, 2022
6a46632
add docs for eps and finalize code
timoklein Dec 5, 2022
cad5fff
use no_grad for actor Q-vals and re-use action-probs & log-probs in a…
timoklein Dec 7, 2022
0cf47f1
update docs for new code and settings
timoklein Dec 14, 2022
61988c4
fix links to point to main branch
timoklein Dec 14, 2022
6e17005
update sac-discrete training plots
timoklein Dec 19, 2022
33b00f3
new sac-d training plots
timoklein Dec 19, 2022
5dabafb
update results table and fix link
timoklein Dec 19, 2022
90b2fd5
fix pong chart title
timoklein Dec 19, 2022
a763994
add Jimmy Ba name as exception to code spell check
timoklein Jan 13, 2023
071cdbb
change target_entropy_scale default value to same value as experiments
timoklein Jan 13, 2023
dcc2633
Merge remote-tracking branch 'upstream/master' into sac-discrete
timoklein Jan 13, 2023
c671a92
remove blank line at end of pre-commit
timoklein Jan 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 359 additions & 0 deletions cleanrl/sac_atari.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,359 @@
# TODO: Must add header here
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
# TODO: Must change this back to CLeanRL later
parser.add_argument("--wandb-project-name", type=str, default="SAC-discrete",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="weather to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="Seaquest-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=120000,
help="total timesteps of the experiments")
parser.add_argument("--buffer-size", type=int, default=int(1.5e5),
help="the replay memory buffer size") # smaller than in original paper but evaluation is done only for 100k steps anyway
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--tau", type=float, default=1.0,
help="target smoothing coefficient (default: 1)") # Default is 1 to perform replacement update
parser.add_argument("--batch-size", type=int, default=64,
help="the batch size of sample from the reply memory")
parser.add_argument("--learning-starts", type=int, default=2e4,
help="timestep to start learning")
parser.add_argument("--policy-lr", type=float, default=3e-4,
help="the learning rate of the policy network optimizer")
parser.add_argument("--q-lr", type=float, default=3e-4,
help="the learning rate of the Q network network optimizer")
parser.add_argument("--policy-frequency", type=int, default=4,
help="the frequency of training policy (delayed)")
parser.add_argument("--target-network-frequency", type=int, default=8000, # Denis Yarats' implementation delays this by 2.
help="the frequency of updates for the target networks")
parser.add_argument("--alpha", type=float, default=0.2,
help="Entropy regularization coefficient.")
parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
help="automatic tuning of the entropy coefficient")
args = parser.parse_args()
# fmt: on
return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = ClipRewardEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env
timoklein marked this conversation as resolved.
Show resolved Hide resolved

return thunk


# He initialization
timoklein marked this conversation as resolved.
Show resolved Hide resolved
def network_init(net, bias_const=0.0):
dosssman marked this conversation as resolved.
Show resolved Hide resolved
for param in net.parameters():
if type(param) == nn.Linear or type(param) == nn.Conv2d:
dosssman marked this conversation as resolved.
Show resolved Hide resolved
nn.init.kaiming_normal_(param.weight)
torch.nn.init.constant_(param.bias, bias_const)

return net


# ALGO LOGIC: initialize agent here:
# NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC
dosssman marked this conversation as resolved.
Show resolved Hide resolved
# See https://arxiv.org/abs/1910.01741 for more info
# TL;DR The actor's gradients mess up the representation when using a joint encoder
class SoftQNetwork(nn.Module):
def __init__(self, envs):
super().__init__()
obs_shape = envs.single_observation_space.shape
self.conv = nn.Sequential(
nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.Flatten(),
nn.ReLU(),
)

with torch.inference_mode():
output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

self.Q_func = nn.Sequential(
nn.Linear(output_dim, 512),
nn.ReLU(),
nn.Linear(512, envs.single_action_space.n),
)

def forward(self, x):
x = self.conv(x / 255.0)
x = self.Q_func(x)
return x


class Actor(nn.Module):
def __init__(self, envs):
super().__init__()
obs_shape = envs.single_observation_space.shape
self.conv = nn.Sequential(
nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.Flatten(),
nn.ReLU(),
)

with torch.inference_mode():
output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

self.fc_logits = nn.Sequential(
nn.Linear(output_dim, 512),
nn.ReLU(),
nn.Linear(512, envs.single_action_space.n),
)

def forward(self, x):
x = self.conv(x)
logits = self.fc_logits(x)

return logits
timoklein marked this conversation as resolved.
Show resolved Hide resolved

def get_action(self, x):
logits = self(x / 255.0)
policy_dist = Categorical(logits=logits)
action = policy_dist.sample()
# Action probabilities for calculating the adapted soft-Q loss
action_probs = policy_dist.probs
log_pi = F.log_softmax(action_probs, dim=1)
timoklein marked this conversation as resolved.
Show resolved Hide resolved
return action, log_pi, action_probs
timoklein marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# Don't utilize full CPU or else I might get warned
# TODO: Must change back before merge
torch.set_num_threads(1)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

# TODO: Change back before merge
device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

actor = network_init(Actor(envs)).to(device)
qf1 = network_init(SoftQNetwork(envs)).to(device)
qf2 = network_init(SoftQNetwork(envs)).to(device)

qf1_target = SoftQNetwork(envs).to(device)
qf2_target = SoftQNetwork(envs).to(device)
qf1_target.load_state_dict(qf1.state_dict())
qf2_target.load_state_dict(qf2.state_dict())
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)

# Automatic entropy tuning
if args.autotune:
target_entropy = -0.98 * torch.log(torch.tensor(envs.single_action_space.n, dtype=float))
timoklein marked this conversation as resolved.
Show resolved Hide resolved
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha = log_alpha.exp().item()
a_optimizer = optim.Adam([log_alpha], lr=args.policy_lr, eps=1e-4)
else:
alpha = args.alpha

rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
# TODO: Can be changed back for cleanRL runs
optimize_memory_usage=False,
handle_timeout_termination=True,
timoklein marked this conversation as resolved.
Show resolved Hide resolved
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
actions = actions.detach().cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
rb.add(obs, real_next_obs, actions, rewards, dones, [info])
timoklein marked this conversation as resolved.
Show resolved Hide resolved

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs

# ALGO LOGIC: training.
if global_step > args.learning_starts:
data = rb.sample(args.batch_size)
# CRITIC training
with torch.no_grad():
_, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
qf1_next_target = qf1_target(data.next_observations)
qf2_next_target = qf2_target(data.next_observations)
# we can use the action probabilities instead of MC sampling to estimate the expectation
min_qf_next_target = next_state_action_probs * (
torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
)
# adapt Q-target for discrete Q-function
min_qf_next_target = min_qf_next_target.sum(dim=1)[..., None]
timoklein marked this conversation as resolved.
Show resolved Hide resolved
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)

# Use Q-values only for the taken actions
qf1_a_values = qf1(data.observations).gather(1, data.actions.long())
qf2_a_values = qf2(data.observations).gather(1, data.actions.long())
qf1_loss = F.mse_loss(qf1_a_values, next_q_value[..., None], reduction="mean")
qf2_loss = F.mse_loss(qf2_a_values, next_q_value[..., None], reduction="mean")
qf_loss = qf1_loss + qf2_loss

q_optimizer.zero_grad()
qf_loss.backward()
q_optimizer.step()

if global_step % args.policy_frequency == 0: # TD 3 Delayed update support
for _ in range(args.policy_frequency):
# ACTOR training
_, log_pi, pi_probs = actor.get_action(data.observations)
with torch.no_grad():
qf1_pi = qf1(data.observations)
qf2_pi = qf2(data.observations)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
# no need for reparameterization, the expectation can be calculated for discrete actions
actor_loss = (pi_probs * ((alpha * log_pi) - min_qf_pi)).mean()

actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()

if args.autotune:
with torch.no_grad():
_, log_pi, action_probs = actor.get_action(data.observations)
# use action probabilities for temperate loss
alpha_loss = (action_probs * (-log_alpha * (log_pi + target_entropy))).mean()

a_optimizer.zero_grad()
alpha_loss.backward()
a_optimizer.step()
alpha = log_alpha.exp().item()

# update the target networks
if global_step % args.target_network_frequency == 0:
for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)

if global_step % 100 == 0:
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
writer.add_scalar("losses/alpha", alpha, global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
if args.autotune:
writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

envs.close()
writer.close()