Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Parallel Q-Networks algorithm (PQN) #472

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
4 changes: 3 additions & 1 deletion cleanrl/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class Args:
"""timestep to start learning"""
train_frequency: int = 10
"""the frequency of training"""
rew_scale: float = 0.1
roger-creus marked this conversation as resolved.
Show resolved Hide resolved
"""the reward scaling factor"""


def make_env(env_id, seed, idx, capture_video, run_name):
Expand Down Expand Up @@ -193,7 +195,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
rb.add(obs, real_next_obs, actions, rewards * args.rew_scale, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
Expand Down
274 changes: 274 additions & 0 deletions cleanrl/pqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be updated

import os
import random
import time
from collections import deque
from dataclasses import dataclass

import envpool
vwxyzjn marked this conversation as resolved.
Show resolved Hide resolved
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from torch.utils.tensorboard import SummaryWriter


@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
torch_deterministic: bool = True
"""if toggled, `torch.backends.cudnn.deterministic=False`"""
cuda: bool = True
"""if toggled, cuda will be enabled by default"""
track: bool = False
"""if toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "cleanRL"
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = False
"""whether to capture videos of the agent performances (check out `videos` folder)"""

# Algorithm specific arguments
env_id: str = "CartPole-v1"
"""the id of the environment"""
total_timesteps: int = 500000
"""total timesteps of the experiments"""
learning_rate: float = 2.5e-4
"""the learning rate of the optimizer"""
num_envs: int = 32
"""the number of parallel game environments"""
num_steps: int = 64
"""the number of steps to run for each environment per update"""
num_minibatches: int = 16
"""the number of mini-batches"""
update_epochs: int = 2
"""the K epochs to update the policy"""
anneal_lr: bool = True
"""Toggle learning rate annealing"""
gamma: float = 0.99
"""the discount factor gamma"""
start_e: float = 1
"""the starting epsilon for exploration"""
end_e: float = 0.05
"""the ending epsilon for exploration"""
exploration_fraction: float = 0.5
"""the fraction of `total_timesteps` it takes from start_e to end_e"""
max_grad_norm: float = 10.0
"""the maximum norm for the gradient clipping"""
rew_scale: float = 0.1
"""the reward scaling factor"""
q_lambda: float = 0.65
"""the lambda for Q(lambda)"""


class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100):
super().__init__(env)
self.num_envs = getattr(env, "num_envs", 1)
self.episode_returns = None
self.episode_lengths = None

def reset(self, **kwargs):
observations = super().reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.lives = np.zeros(self.num_envs, dtype=np.int32)
self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations

def step(self, action):
observations, rewards, dones, infos = super().step(action)
self.episode_returns += rewards
self.episode_lengths += 1
self.returned_episode_returns[:] = self.episode_returns
self.returned_episode_lengths[:] = self.episode_lengths
self.episode_returns *= 1 - dones
self.episode_lengths *= 1 - dones
infos["r"] = self.returned_episode_returns
infos["l"] = self.returned_episode_lengths
return (
observations,
rewards,
dones,
infos,
)


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
def __init__(self, env):
super().__init__()

self.network = nn.Sequential(
nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
nn.LayerNorm(120),
nn.ReLU(),
nn.Linear(120, 84),
nn.LayerNorm(84),
nn.ReLU(),
nn.Linear(84, env.single_action_space.n),
)

def forward(self, x):
return self.network(x)


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
slope = (end_e - start_e) / duration
return max(slope * t + start_e, end_e)


if __name__ == "__main__":
args = tyro.cli(Args)
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
args.num_iterations = args.total_timesteps // args.batch_size
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb

wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = envpool.make(
args.env_id,
env_type="gym",
num_envs=args.num_envs,
seed=args.seed,
)
envs.num_envs = args.num_envs
envs.single_action_space = envs.action_space
envs.single_observation_space = envs.observation_space
envs = RecordEpisodeStatistics(envs)
assert isinstance(envs.action_space, gym.spaces.Discrete), "only discrete action space is supported"

# agent setup
q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)

# storage setup
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)
avg_returns = deque(maxlen=20)

global_step = 0
start_time = time.time()

# TRY NOT TO MODIFY: start the game
next_obs = torch.Tensor(envs.reset()).to(device)
next_done = torch.zeros(args.num_envs).to(device)

for iteration in range(1, args.num_iterations + 1):
# Annealing the rate if instructed to do so.
if args.anneal_lr:
frac = 1.0 - (iteration - 1.0) / args.num_iterations
lrnow = frac * args.learning_rate
optimizer.param_groups[0]["lr"] = lrnow

for step in range(0, args.num_steps):
global_step += args.num_envs
obs[step] = next_obs
dones[step] = next_done

epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)

random_actions = torch.randint(0, envs.single_action_space.n, (args.num_envs,)).to(device)
with torch.no_grad():
q_values = q_network(next_obs)
max_actions = torch.argmax(q_values, dim=1)
values[step] = q_values[torch.arange(args.num_envs), max_actions].flatten()

explore = (torch.rand((args.num_envs,)).to(device) < epsilon)
action = torch.where(explore, random_actions, max_actions)
actions[step] = action

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, next_done, info = envs.step(action.cpu().numpy())
rewards[step] = torch.tensor(reward).to(device).view(-1) * args.rew_scale
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

# TRY NOT TO MODIFY: record rewards for plotting purposes
for idx, d in enumerate(next_done):
if d:
print(f"global_step={global_step}, episodic_return={info['r'][idx]}")
avg_returns.append(info["r"][idx])
writer.add_scalar("charts/avg_episodic_return", np.average(avg_returns), global_step)
writer.add_scalar("charts/episodic_return", info["r"][idx], global_step)
writer.add_scalar("charts/episodic_length", info["l"][idx], global_step)

# Compute Q(lambda) targets
with torch.no_grad():
returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
next_value, _ = torch.max(q_network(next_obs), dim=-1)
nextnonterminal = 1.0 - next_done
returns[t] = rewards[t] + args.gamma * next_value * nextnonterminal
else:
nextnonterminal = 1.0 - dones[t + 1]
next_value = values[t + 1]
returns[t] = rewards[t] + args.gamma * (
args.q_lambda * returns[t + 1] + (1 - args.q_lambda) * next_value * nextnonterminal
)

# flatten the batch
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
b_returns = returns.reshape(-1)

# Optimizing the Q-network
b_inds = np.arange(args.batch_size)
for epoch in range(args.update_epochs):
np.random.shuffle(b_inds)
for start in range(0, args.batch_size, args.minibatch_size):
end = start + args.minibatch_size
mb_inds = b_inds[start:end]

old_val = q_network(b_obs[mb_inds]).gather(1, b_actions[mb_inds].unsqueeze(-1).long()).squeeze()
loss = F.mse_loss(b_returns[mb_inds], old_val)

# optimize the model
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(q_network.parameters(), args.max_grad_norm)
optimizer.step()

writer.add_scalar("losses/td_loss", loss, global_step)
writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

envs.close()
writer.close()
Loading
Loading