SAC-discrete implementation #270

Merged
42 commits merged on Jan 13, 2023
Commits
18b643b
add draft of SAC discrete implementation
timoklein Aug 29, 2022
c3c98bd
run pre-commit
timoklein Aug 29, 2022
ec31dc4
Use log softmax instead of author's log-pi code
timoklein Aug 31, 2022
deb37e8
Revert to cleanrl SAC delay implementation (it's more stable)
timoklein Aug 31, 2022
a1fdd2b
Remove docstrings and duplicate code
timoklein Aug 31, 2022
977a83a
Use correct clipreward wrapper
timoklein Aug 31, 2022
f2ea3e6
fix bug in log softmax calculation
timoklein Sep 6, 2022
48af04c
adhere to cleanrl log_prob naming
timoklein Sep 6, 2022
b2a09a0
fix bug in entropy target calculation
timoklein Sep 6, 2022
89680c7
change layer initialization to match existing cleanrl codebase
timoklein Sep 6, 2022
b1d7d44
working minimal diff version
timoklein Sep 19, 2022
61e1c74
implement original learning update frequency
timoklein Sep 20, 2022
7cd1e3a
parameterize the entropy scale for autotuning
timoklein Oct 3, 2022
61c46fc
add benchmarking script
timoklein Oct 3, 2022
4915e4c
rename target entropy factor and set new default value
timoklein Oct 6, 2022
6f7251f
add docs draft
timoklein Nov 5, 2022
23b60ff
fix SAC-discrete links to work pre merge
timoklein Nov 10, 2022
10ee9f0
add preliminary result table for SAC-discrete
timoklein Nov 10, 2022
8430fd8
clean up todos and add header
timoklein Nov 10, 2022
a17768c
minimize diff between sac_atari and sac_continuous
timoklein Nov 11, 2022
d6a507c
add sac-discrete end2end test
timoklein Nov 11, 2022
a7ea6f4
SAC-discrete docs rework
timoklein Nov 11, 2022
9f6493c
Update SAC-discrete @100k results
timoklein Nov 12, 2022
59a6d00
Fix doc links and unify naming in code
timoklein Nov 12, 2022
1304b7a
update docs
vwxyzjn Nov 13, 2022
3a3f41b
fix target update frequency (see PR #323)
timoklein Nov 24, 2022
80187ad
clarify comment regarding CNN encoder sharing
timoklein Nov 24, 2022
e9cb494
Merge remote-tracking branch 'upstream/master' into sac-discrete
timoklein Nov 25, 2022
e199e39
fix benchmark installation
timoklein Nov 25, 2022
bb27fa1
fix eps in minimal diff version and improve code readability
timoklein Dec 3, 2022
6a46632
add docs for eps and finalize code
timoklein Dec 5, 2022
cad5fff
use no_grad for actor Q-vals and re-use action-probs & log-probs in a…
timoklein Dec 7, 2022
0cf47f1
update docs for new code and settings
timoklein Dec 14, 2022
61988c4
fix links to point to main branch
timoklein Dec 14, 2022
6e17005
update sac-discrete training plots
timoklein Dec 19, 2022
33b00f3
new sac-d training plots
timoklein Dec 19, 2022
5dabafb
update results table and fix link
timoklein Dec 19, 2022
90b2fd5
fix pong chart title
timoklein Dec 19, 2022
a763994
add Jimmy Ba name as exception to code spell check
timoklein Jan 13, 2023
071cdbb
change target_entropy_scale default value to same value as experiments
timoklein Jan 13, 2023
dcc2633
Merge remote-tracking branch 'upstream/master' into sac-discrete
timoklein Jan 13, 2023
c671a92
remove blank line at end of pre-commit
timoklein Jan 13, 2023
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -35,7 +35,7 @@ repos:
hooks:
- id: codespell
args:
-        - --ignore-words-list=nd,reacher,thist,ths,magent
+        - --ignore-words-list=nd,reacher,thist,ths,magent,Ba
- --skip=docs/css/termynal.css,docs/js/termynal.js,docs/get-started/CleanRL_Huggingface_Integration_Demo.ipynb
- repo: https://github.com/python-poetry/poetry
rev: 1.2.1
@@ -79,4 +79,4 @@ repos:
- id: poetry-export
name: poetry-export requirements-cloud.txt
args: ["--without-hashes", "-o", "requirements/requirements-cloud.txt", "--with", "cloud"]
stages: [manual]
4 changes: 3 additions & 1 deletion README.md
@@ -84,6 +84,7 @@ poetry install --with atari
python cleanrl/dqn_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/c51_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/ppo_atari.py --env-id BreakoutNoFrameskip-v4
python cleanrl/sac_atari.py --env-id BreakoutNoFrameskip-v4

# NEW: 3-4x side-effects free speed up with envpool's atari (only available to linux)
poetry install --with envpool
@@ -129,14 +130,15 @@ You may also use a prebuilt development environment hosted in Gitpod:
| | [`ppo_pettingzoo_ma_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy)
| | [`ppo_continuous_action_isaacgym.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_action_isaacgympy)
| ✅ [Deep Q-Learning (DQN)](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf) | [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy) |
| | [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy) |
| | [`dqn_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy) |
| | [`dqn_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_atari_jaxpy) |
| ✅ [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) | [`c51.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51py) |
| | [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_ataripy) |
| | [`c51_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_jaxpy) |
| | [`c51_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/c51/#c51_atari_jaxpy) |
| ✅ [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) | [`sac_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/sac/#sac_continuous_actionpy) |
| | [`sac_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy) |
| ✅ [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) | [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_actionpy) |
| | [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ddpg/#ddpg_continuous_action_jaxpy)
| ✅ [Twin Delayed Deep Deterministic Policy Gradient (TD3)](https://arxiv.org/pdf/1802.09477.pdf) | [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py), [docs](https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_actionpy) |
6 changes: 6 additions & 0 deletions benchmark/sac_atari.sh
@@ -0,0 +1,6 @@
poetry install --with atari
OMP_NUM_THREADS=1 python -m cleanrl_utils.benchmark \
--env-ids PongNoFrameskip-v4 BreakoutNoFrameskip-v4 BeamRiderNoFrameskip-v4 \
--command "poetry run python cleanrl/sac_atari.py --cuda True --track" \
--num-seeds 3 \
--workers 2
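
For reference, the invocation above fans the base command out over every (env-id, seed) pair and keeps at most two runs in flight at a time. The snippet below is only a rough, hypothetical illustration of that expansion; it is not the cleanrl_utils.benchmark source, and the seed numbering is an assumption.

# Hypothetical sketch of what the benchmark invocation expands to:
# 3 env ids x 3 seeds = 9 training runs, executed 2 at a time.
import shlex
import subprocess
from concurrent.futures import ThreadPoolExecutor

env_ids = ["PongNoFrameskip-v4", "BreakoutNoFrameskip-v4", "BeamRiderNoFrameskip-v4"]
base_command = "poetry run python cleanrl/sac_atari.py --cuda True --track"
commands = [
    f"{base_command} --env-id {env_id} --seed {seed}"
    for env_id in env_ids
    for seed in range(1, 4)  # assumed seed numbering
]

with ThreadPoolExecutor(max_workers=2) as pool:
    # each worker blocks until its training run finishes
    list(pool.map(lambda cmd: subprocess.run(shlex.split(cmd), check=True), commands))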
342 changes: 342 additions & 0 deletions cleanrl/sac_atari.py
@@ -0,0 +1,342 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/sac/#sac_ataripy
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireResetEnv,
MaxAndSkipEnv,
NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
help="the name of this experiment")
parser.add_argument("--seed", type=int, default=1,
help="seed of the experiment")
parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, `torch.backends.cudnn.deterministic=False`")
parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
help="if toggled, cuda will be enabled by default")
parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="if toggled, this experiment will be tracked with Weights and Biases")
parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
help="the wandb's project name")
parser.add_argument("--wandb-entity", type=str, default=None,
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="weather to capture videos of the agent performances (check out `videos` folder)")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="BeamRiderNoFrameskip-v4",
help="the id of the environment")
parser.add_argument("--total-timesteps", type=int, default=5000000,
help="total timesteps of the experiments")
parser.add_argument("--buffer-size", type=int, default=int(1e6),
help="the replay memory buffer size") # smaller than in original paper but evaluation is done only for 100k steps anyway
parser.add_argument("--gamma", type=float, default=0.99,
help="the discount factor gamma")
parser.add_argument("--tau", type=float, default=1.0,
help="target smoothing coefficient (default: 1)") # Default is 1 to perform replacement update
parser.add_argument("--batch-size", type=int, default=64,
help="the batch size of sample from the reply memory")
parser.add_argument("--learning-starts", type=int, default=2e4,
help="timestep to start learning")
parser.add_argument("--policy-lr", type=float, default=3e-4,
help="the learning rate of the policy network optimizer")
parser.add_argument("--q-lr", type=float, default=3e-4,
help="the learning rate of the Q network network optimizer")
parser.add_argument("--update-frequency", type=int, default=4,
help="the frequency of training updates")
parser.add_argument("--target-network-frequency", type=int, default=8000,
help="the frequency of updates for the target networks")
parser.add_argument("--alpha", type=float, default=0.2,
help="Entropy regularization coefficient.")
parser.add_argument("--autotune", type=lambda x:bool(strtobool(x)), default=True, nargs="?", const=True,
help="automatic tuning of the entropy coefficient")
parser.add_argument("--target-entropy-scale", type=float, default=0.89,
help="coefficient for scaling the autotune entropy target")
args = parser.parse_args()
# fmt: on
return args


def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
if "FIRE" in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = ClipRewardEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk


def layer_init(layer, bias_const=0.0):
nn.init.kaiming_normal_(layer.weight)
torch.nn.init.constant_(layer.bias, bias_const)
return layer


# ALGO LOGIC: initialize agent here:
# NOTE: Sharing a CNN encoder between Actor and Critics is not recommended for SAC without stopping actor gradients
# See the SAC+AE paper https://arxiv.org/abs/1910.01741 for more info
# TL;DR The actor's gradients mess up the representation when using a joint encoder
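# If a shared encoder is desired anyway, the SAC+AE recipe is to update the CNN only with
# the critic loss and to feed the actor detached features, e.g. (hypothetical sketch, not
# part of this file): logits = actor_head(encoder(x / 255.0).detach()).
# This file takes the simpler route: Actor and SoftQNetwork each own a separate CNN.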
class SoftQNetwork(nn.Module):
def __init__(self, envs):
super().__init__()
obs_shape = envs.single_observation_space.shape
self.conv = nn.Sequential(
layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
nn.ReLU(),
layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
nn.ReLU(),
layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
nn.Flatten(),
)

with torch.inference_mode():
output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]
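# The dummy zero observation above is used only to infer the flattened size of the conv
# output; inference_mode keeps this probe out of the autograd graph.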

self.fc1 = layer_init(nn.Linear(output_dim, 512))
self.fc_q = layer_init(nn.Linear(512, envs.single_action_space.n))

def forward(self, x):
x = F.relu(self.conv(x / 255.0))
x = F.relu(self.fc1(x))
q_vals = self.fc_q(x)
return q_vals
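# Unlike continuous SAC, the critic does not take the action as an input: it outputs one
# Q-value per discrete action, so Q(s, a) is just a column lookup (see the gather calls below).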


class Actor(nn.Module):
def __init__(self, envs):
super().__init__()
obs_shape = envs.single_observation_space.shape
self.conv = nn.Sequential(
layer_init(nn.Conv2d(obs_shape[0], 32, kernel_size=8, stride=4)),
nn.ReLU(),
layer_init(nn.Conv2d(32, 64, kernel_size=4, stride=2)),
nn.ReLU(),
layer_init(nn.Conv2d(64, 64, kernel_size=3, stride=1)),
nn.Flatten(),
)

with torch.inference_mode():
output_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]

self.fc1 = layer_init(nn.Linear(output_dim, 512))
self.fc_logits = layer_init(nn.Linear(512, envs.single_action_space.n))

def forward(self, x):
x = F.relu(self.conv(x))
x = F.relu(self.fc1(x))
logits = self.fc_logits(x)

return logits

def get_action(self, x):
logits = self(x / 255.0)
policy_dist = Categorical(logits=logits)
action = policy_dist.sample()
# Action probabilities for calculating the adapted soft-Q loss
action_probs = policy_dist.probs
log_prob = F.log_softmax(logits, dim=1)
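# Log-probabilities over the full action set, shape (batch_size, n_actions). Together with
# action_probs they let the losses below take exact expectations over all discrete actions
# rather than relying only on the sampled action.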
return action, log_prob, action_probs


if __name__ == "__main__":
args = parse_args()
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

actor = Actor(envs).to(device)
qf1 = SoftQNetwork(envs).to(device)
qf2 = SoftQNetwork(envs).to(device)
qf1_target = SoftQNetwork(envs).to(device)
qf2_target = SoftQNetwork(envs).to(device)
qf1_target.load_state_dict(qf1.state_dict())
qf2_target.load_state_dict(qf2.state_dict())
# TRY NOT TO MODIFY: eps=1e-4 increases numerical stability
q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.q_lr, eps=1e-4)
actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.policy_lr, eps=1e-4)

# Automatic entropy tuning
if args.autotune:
target_entropy = -args.target_entropy_scale * torch.log(1 / torch.tensor(envs.single_action_space.n))
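# The maximum entropy of a discrete policy is log(n_actions) (the uniform policy), so the
# target is a fixed fraction of it: target_entropy = target_entropy_scale * log(n_actions).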
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha = log_alpha.exp().item()
a_optimizer = optim.Adam([log_alpha], lr=args.q_lr, eps=1e-4)
else:
alpha = args.alpha

rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device,
handle_timeout_termination=True,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs = envs.reset()
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
actions = actions.detach().cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, dones, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for info in infos:
if "episode" in info.keys():
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(dones):
if d:
real_next_obs[idx] = infos[idx]["terminal_observation"]
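# Vectorized envs auto-reset on done, so next_obs already holds the first observation of the
# next episode; the loop above recovers the true terminal observation from infos so that the
# buffer stores the correct transition.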
rb.add(obs, real_next_obs, actions, rewards, dones, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs

# ALGO LOGIC: training.
if global_step > args.learning_starts:
if global_step % args.update_frequency == 0:
data = rb.sample(args.batch_size)
# CRITIC training
with torch.no_grad():
_, next_state_log_pi, next_state_action_probs = actor.get_action(data.next_observations)
qf1_next_target = qf1_target(data.next_observations)
qf2_next_target = qf2_target(data.next_observations)
# we can use the action probabilities instead of MC sampling to estimate the expectation
min_qf_next_target = next_state_action_probs * (
torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
)
# adapt Q-target for discrete Q-function
min_qf_next_target = min_qf_next_target.sum(dim=1)
next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target)
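# In equation form, the critic target is
#   y = r + gamma * (1 - done) * sum_a pi(a|s') * ( min_i Q_target_i(s', a) - alpha * log pi(a|s') ),
# i.e. the exact discrete soft state value of s' replaces the single-sample estimate used in
# continuous SAC.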

# use Q-values only for the taken actions
qf1_values = qf1(data.observations)
qf2_values = qf2(data.observations)
qf1_a_values = qf1_values.gather(1, data.actions.long()).view(-1)
qf2_a_values = qf2_values.gather(1, data.actions.long()).view(-1)
qf1_loss = F.mse_loss(qf1_a_values, next_q_value)
qf2_loss = F.mse_loss(qf2_a_values, next_q_value)
qf_loss = qf1_loss + qf2_loss

q_optimizer.zero_grad()
qf_loss.backward()
q_optimizer.step()

# ACTOR training
_, log_pi, action_probs = actor.get_action(data.observations)
with torch.no_grad():
qf1_values = qf1(data.observations)
qf2_values = qf2(data.observations)
min_qf_values = torch.min(qf1_values, qf2_values)
# no need for reparameterization, the expectation can be calculated for discrete actions
actor_loss = (action_probs * ((alpha * log_pi) - min_qf_values)).mean()
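# Discrete-action policy objective:
#   J_pi = E_s[ sum_a pi(a|s) * ( alpha * log pi(a|s) - min_i Q_i(s, a) ) ]
# Because the expectation over actions is computed exactly from action_probs, no
# reparameterization trick is needed.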

actor_optimizer.zero_grad()
actor_loss.backward()
actor_optimizer.step()

if args.autotune:
# re-use action probabilities for temperature loss
alpha_loss = (action_probs.detach() * (-log_alpha * (log_pi + target_entropy).detach())).mean()
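# Temperature objective, again as an exact expectation over actions:
#   J(alpha) = E_s[ sum_a pi(a|s) * ( -log_alpha * (log pi(a|s) + target_entropy) ) ]
# Its gradient increases alpha whenever the policy entropy falls below target_entropy and
# decreases it otherwise; only log_alpha receives gradients here.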

a_optimizer.zero_grad()
alpha_loss.backward()
a_optimizer.step()
alpha = log_alpha.exp().item()

# update the target networks
if global_step % args.target_network_frequency == 0:
for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
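# With the default tau=1.0 this is a hard (replacement) update: every target_network_frequency
# environment steps the target networks copy the online critic weights exactly.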

if global_step % 100 == 0:
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
writer.add_scalar("losses/alpha", alpha, global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
if args.autotune:
writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

envs.close()
writer.close()
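
As a closing illustration of the comment "we can use the action probabilities instead of MC sampling to estimate the expectation", the self-contained sketch below (hypothetical tensors and names, not part of the PR) computes the discrete soft state value once as an exact expectation, the way sac_atari.py does, and once by Monte Carlo sampling, and checks that the two agree up to sampling noise.

import torch
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

torch.manual_seed(0)
batch, n_actions, alpha = 4, 6, 0.2
logits = torch.randn(batch, n_actions)  # stand-in for Actor(next_observations)
q1 = torch.randn(batch, n_actions)      # stand-in for qf1_target(next_observations)
q2 = torch.randn(batch, n_actions)      # stand-in for qf2_target(next_observations)

dist = Categorical(logits=logits)
probs = dist.probs
log_probs = F.log_softmax(logits, dim=1)
min_q = torch.min(q1, q2)

# Exact expectation over the discrete action set (what the critic update above uses)
v_exact = (probs * (min_q - alpha * log_probs)).sum(dim=1)

# Monte Carlo estimate from sampled actions converges to the same value
actions = dist.sample((20000,)).T  # shape (batch, 20000)
v_mc = (min_q.gather(1, actions) - alpha * log_probs.gather(1, actions)).mean(dim=1)

print(torch.allclose(v_exact, v_mc, atol=0.1))  # expected: True, up to sampling noise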