Skip to content

Commit

Permalink
[Feature] DDPG compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 9d88c127091ad6e711ef49e9f48aaedad15b1c05
Pull Request resolved: #2555
  • Loading branch information
vmoens committed Nov 12, 2024
1 parent d05a956 commit 049a5ba
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 56 deletions.
5 changes: 4 additions & 1 deletion sota-implementations/ddpg/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ collector:
frames_per_batch: 1000
init_env_steps: 1000
reset_at_each_iter: False
device: cpu
device:
env_per_collector: 1


Expand All @@ -39,6 +39,9 @@ network:
hidden_sizes: [256, 256]
activation: relu
noise_type: "ou" # ou or gaussian
compile: False
compile_mode:
cudagraphs: False

# logging
logger:
Expand Down
125 changes: 71 additions & 54 deletions sota-implementations/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,21 @@
The helper functions are coded in the utils.py associated with this script.
"""
import time
import warnings

import hydra

import numpy as np
import torch
import torch.cuda
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import timeit

from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
dump_video,
Expand Down Expand Up @@ -73,8 +77,24 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create DDPG loss
loss_module, target_net_updater = make_loss_module(cfg, model)

compile_mode = None
if cfg.network.compile:
if cfg.network.compile_mode not in (None, ""):
compile_mode = cfg.network.compile_mode
elif cfg.network.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create off-policy collector
collector = make_collector(cfg, train_env, exploration_policy)
collector = make_collector(
cfg,
train_env,
exploration_policy,
compile=cfg.network.compile,
compile_mode=compile_mode,
cudagraph=cfg.network.cudagraphs,
)

# Create replay buffer
replay_buffer = make_replay_buffer(
Expand All @@ -87,9 +107,29 @@ def main(cfg: "DictConfig"): # noqa: F821

# Create optimizers
optimizer_actor, optimizer_critic = make_optimizer(cfg, loss_module)
optimizer = group_optimizers(optimizer_actor, optimizer_critic)

def update(sampled_tensordict):
optimizer.zero_grad(set_to_none=True)

td_loss: TensorDict = loss_module(sampled_tensordict)
td_loss.sum(reduce=True).backward()
optimizer.step()

# Update qnet_target params
target_net_updater.step()
return td_loss.detach()

if cfg.loss.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Main loop
start_time = time.time()
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.collector.total_frames)

Expand All @@ -104,63 +144,42 @@ def main(cfg: "DictConfig"): # noqa: F821
eval_iter = cfg.logger.eval_iter
eval_rollout_steps = cfg.env.max_episode_steps

sampling_start = time.time()
for _, tensordict in enumerate(collector):
sampling_time = time.time() - sampling_start
c_iter = iter(collector)
for i in range(len(collector)):
with timeit("collecting"):
tensordict = next(c_iter)
# Update exploration policy
exploration_policy[1].step(tensordict.numel())

# Update weights of the inference policy
collector.update_policy_weights_()

pbar.update(tensordict.numel())

tensordict = tensordict.reshape(-1)
current_frames = tensordict.numel()
pbar.update(current_frames)

# Add to replay buffer
replay_buffer.extend(tensordict.cpu())
with timeit("rb - extend"):
tensordict = tensordict.reshape(-1)
replay_buffer.extend(tensordict)

collected_frames += current_frames

# Optimization steps
training_start = time.time()
if collected_frames >= init_random_frames:
(
actor_losses,
q_losses,
) = ([], [])
tds = []
for _ in range(num_updates):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(
device, non_blocking=True
)
else:
sampled_tensordict = sampled_tensordict.clone()

# Update critic
q_loss, *_ = loss_module.loss_value(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Update actor
actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

q_losses.append(q_loss.item())
actor_losses.append(actor_loss.item())

# Update qnet_target params
target_net_updater.step()
with timeit("rb - sample"):
sampled_tensordict = replay_buffer.sample().to(device)
with timeit("update"):
td_loss = update(sampled_tensordict)
tds.append(td_loss.clone())

# Update priority
if prb:
replay_buffer.update_priority(sampled_tensordict)
tds = torch.stack(tds)

training_time = time.time() - training_start
episode_end = (
tensordict["next", "done"]
if tensordict["next", "done"].any()
Expand All @@ -178,38 +197,36 @@ def main(cfg: "DictConfig"): # noqa: F821
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/sampling_time"] = sampling_time
metrics_to_log["train/training_time"] = training_time
tds = TensorDict(train=tds).flatten_keys("/").mean()
metrics_to_log.update(tds.to_dict())

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_start = time.time()
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
auto_cast_to_device=True,
break_when_any_done=True,
)
eval_env.apply(dump_video)
eval_time = time.time() - eval_start
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
metrics_to_log["eval/reward"] = eval_reward
metrics_to_log["eval/time"] = eval_time
if i % 20 == 0:
metrics_to_log.update(timeit.todict(prefix="time"))
timeit.print()
timeit.erase()

if logger is not None:
log_metrics(logger, metrics_to_log, collected_frames)
sampling_start = time.time()

collector.shutdown()
end_time = time.time()
execution_time = end_time - start_time
if not eval_env.is_closed:
eval_env.close()
if not train_env.is_closed:
train_env.close()
torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish")


if __name__ == "__main__":
Expand Down
11 changes: 10 additions & 1 deletion sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,14 @@ def make_environment(cfg, logger):
# ---------------------------


def make_collector(cfg, train_env, actor_model_explore):
def make_collector(
cfg,
train_env,
actor_model_explore,
compile=False,
compile_mode=None,
cudagraph=False,
):
"""Make collector."""
collector = SyncDataCollector(
train_env,
Expand All @@ -123,6 +130,8 @@ def make_collector(cfg, train_env, actor_model_explore):
reset_at_each_iter=cfg.collector.reset_at_each_iter,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
collector.set_seed(cfg.env.seed)
return collector
Expand Down

0 comments on commit 049a5ba

Please sign in to comment.