Change adversarial algorithms to collect rollouts first #731

Draft
wants to merge 10 commits into base: master
157 changes: 141 additions & 16 deletions src/imitation/algorithms/adversarial/common.py
@@ -8,12 +8,17 @@
import torch as th
import torch.utils.tensorboard as thboard
import tqdm
from stable_baselines3.common import base_class, on_policy_algorithm, policies, vec_env
from stable_baselines3.common import base_class
from stable_baselines3.common import buffers as sb3_buffers
from stable_baselines3.common import on_policy_algorithm, policies, type_aliases
from stable_baselines3.common import utils as sb3_utils
from stable_baselines3.common import vec_env
from stable_baselines3.sac import policies as sac_policies
from torch.nn import functional as F

from imitation.algorithms import base
from imitation.data import buffer, rollout, types, wrappers
from imitation.policies import replay_buffer_wrapper
from imitation.rewards import reward_nets, reward_wrapper
from imitation.util import logger, networks, util

@@ -246,6 +251,38 @@ def __init__(
else:
self.gen_train_timesteps = gen_train_timesteps

self.is_gen_on_policy = isinstance(
self.gen_algo, on_policy_algorithm.OnPolicyAlgorithm
)
if self.is_gen_on_policy:
rollout_buffer = self.gen_algo.rollout_buffer
self.gen_algo.rollout_buffer = (
replay_buffer_wrapper.RolloutBufferRewardWrapper(
buffer_size=self.gen_train_timesteps // rollout_buffer.n_envs,
observation_space=rollout_buffer.observation_space,
action_space=rollout_buffer.action_space,
rollout_buffer_class=rollout_buffer.__class__,
reward_fn=self.reward_train.predict_processed,
device=rollout_buffer.device,
gae_lambda=rollout_buffer.gae_lambda,
gamma=rollout_buffer.gamma,
n_envs=rollout_buffer.n_envs,
)
)
else:
replay_buffer = self.gen_algo.replay_buffer
self.gen_algo.replay_buffer = (
replay_buffer_wrapper.ReplayBufferRewardWrapper(
buffer_size=self.gen_train_timesteps,
observation_space=replay_buffer.observation_space,
action_space=replay_buffer.action_space,
replay_buffer_class=sb3_buffers.ReplayBuffer,
reward_fn=self.reward_train.predict_processed,
device=replay_buffer.device,
n_envs=replay_buffer.n_envs,
)
)

if gen_replay_buffer_capacity is None:
gen_replay_buffer_capacity = self.gen_train_timesteps
self._gen_replay_buffer = buffer.ReplayBuffer(
@@ -382,41 +419,126 @@ def train_disc(

return train_stats

def collect_rollouts(
self,
total_timesteps: Optional[int] = None,
callback: type_aliases.MaybeCallback = None,
learn_kwargs: Optional[Mapping] = None,
):
"""Collect rollouts.

Args:
total_timesteps: The number of transitions to sample from
`self.venv_train` during rollout collection. By default,
`self.gen_train_timesteps`.
callback: Callback that will be called at each step
(and at the beginning and end of the rollout).
learn_kwargs: Additional kwargs passed through to the Stable Baselines
`_setup_learn()` call.

Returns:
`None` for on-policy generators; otherwise the `RolloutReturn`
produced by the off-policy `collect_rollouts()` call.
"""
if learn_kwargs is None:
learn_kwargs = {}

if total_timesteps is None:
total_timesteps = self.gen_train_timesteps

# total timesteps should be per env
total_timesteps = total_timesteps // self.gen_algo.n_envs
# NOTE (Taufeeque): call setup_learn or not?
if "eval_env" not in learn_kwargs:
total_timesteps, callback = self.gen_algo._setup_learn(
total_timesteps,
eval_env=None,
callback=callback,
**learn_kwargs,
)
else:
total_timesteps, callback = self.gen_algo._setup_learn(
total_timesteps,
callback=callback,
**learn_kwargs,
)
callback.on_training_start(locals(), globals())
if self.is_gen_on_policy:
self.gen_algo.collect_rollouts(
self.gen_algo.env,
callback,
self.gen_algo.rollout_buffer,
n_rollout_steps=total_timesteps,
)
rollouts = None
else:
self.gen_algo.train_freq = total_timesteps
self.gen_algo._convert_train_freq()
rollouts = self.gen_algo.collect_rollouts(
self.gen_algo.env,
train_freq=self.gen_algo.train_freq,
action_noise=self.gen_algo.action_noise,
callback=callback,
learning_starts=self.gen_algo.learning_starts,
replay_buffer=self.gen_algo.replay_buffer,
)

if self.is_gen_on_policy:
if (
len(self.gen_algo.ep_info_buffer) > 0
and len(self.gen_algo.ep_info_buffer[0]) > 0
):
self.logger.record(
"rollout/ep_rew_mean",
sb3_utils.safe_mean(
[ep_info["r"] for ep_info in self.gen_algo.ep_info_buffer]
),
)
self.logger.record(
"rollout/ep_len_mean",
sb3_utils.safe_mean(
[ep_info["l"] for ep_info in self.gen_algo.ep_info_buffer]
),
)
self.logger.record(
"time/total_timesteps",
self.gen_algo.num_timesteps,
exclude="tensorboard",
)
else:
self.gen_algo._dump_logs()

gen_trajs, ep_lens = self.venv_buffering.pop_trajectories()
self._check_fixed_horizon(ep_lens)
gen_samples = rollout.flatten_trajectories_with_rew(gen_trajs)
self._gen_replay_buffer.store(gen_samples)
callback.on_training_end()
return rollouts
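
A hedged usage sketch, an editor's illustration rather than part of the diff: assuming `trainer` is an instance of an `AdversarialTrainer` subclass such as GAIL, a round now begins with

rollouts = trainer.collect_rollouts(
    total_timesteps=trainer.gen_train_timesteps,
    callback=trainer.gen_callback,
)
# For on-policy generators (e.g. PPO) the wrapped rollout_buffer is now full
# and `rollouts` is None; for off-policy generators (e.g. SAC) it is the
# RolloutReturn from SB3's collect_rollouts, which train_gen later uses to
# size its gradient steps. In both cases the fresh trajectories have also
# been stored in the generator replay buffer consumed by train_disc.
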

def train_gen(
self,
rollouts,
) -> None:
"""Trains the generator to maximize the discriminator loss.

After the end of training populates the generator replay buffer (used in
discriminator training) with `self.disc_batch_size` transitions.
"""
with self.logger.accumulate_means("gen"):
# self.gen_algo.learn(
# total_timesteps=total_timesteps,
# reset_num_timesteps=False,
# callback=self.gen_callback,
# **learn_kwargs,
# )
if self.is_gen_on_policy:
self.gen_algo.train()
else:
if self.gen_algo.gradient_steps >= 0:
gradient_steps = self.gen_algo.gradient_steps
else:
gradient_steps = rollouts.episode_timesteps
self.gen_algo.train(
batch_size=self.gen_algo.batch_size,
gradient_steps=gradient_steps,
)
self._global_step += 1

def train(
self,
@@ -445,11 +567,14 @@ def train(
f"total_timesteps={total_timesteps})!"
)
for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
self.train_gen(self.gen_train_timesteps)
rollouts = self.collect_rollouts(
self.gen_train_timesteps, self.gen_callback
)
for _ in range(self.n_disc_updates_per_round):
with networks.training(self.reward_train):
# switch to training mode (affects dropout, normalization)
self.train_disc()
self.train_gen(rollouts)
if callback:
callback(r)
self.logger.dump(self._global_step)
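
To make the ordering change explicit, here is a hedged side-by-side sketch (editor's illustration only; `trainer` again stands for an `AdversarialTrainer` instance):

def old_round(trainer):
    # Before this PR: generate rollouts and train the generator in one call...
    trainer.train_gen(trainer.gen_train_timesteps)
    for _ in range(trainer.n_disc_updates_per_round):
        # ...then update the discriminator on those samples.
        trainer.train_disc()

def new_round(trainer):
    # With this PR: collect rollouts first (rewards relabeled by the wrapped
    # buffer), update the discriminator, and only then take generator
    # gradient steps on the same rollouts.
    rollouts = trainer.collect_rollouts(
        trainer.gen_train_timesteps, trainer.gen_callback
    )
    for _ in range(trainer.n_disc_updates_per_round):
        trainer.train_disc()
    trainer.train_gen(rollouts)
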
140 changes: 136 additions & 4 deletions src/imitation/policies/replay_buffer_wrapper.py
@@ -3,15 +3,16 @@
from typing import Mapping, Type

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.buffers import BaseBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.type_aliases import ReplayBufferSamples

from imitation.rewards.reward_function import RewardFn
from imitation.util import util


def _samples_to_reward_fn_input(
def _replay_samples_to_reward_fn_input(
samples: ReplayBufferSamples,
) -> Mapping[str, np.ndarray]:
"""Convert a sample from a replay buffer to a numpy array."""
@@ -23,6 +24,18 @@ def _samples_to_reward_fn_input(
)


def _rollout_samples_to_reward_fn_input(
buffer: RolloutBuffer,
) -> Mapping[str, np.ndarray]:
"""Convert a sample from a rollout buffer to a numpy array."""
return dict(
state=buffer.observations,
action=buffer.actions,
next_state=buffer.next_observations,
done=buffer.dones,
)
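
The two helpers above assemble the keyword arguments handed to `reward_fn`. As a minimal sketch (editor's illustration; `example_reward_fn` is a placeholder, not part of this PR), any callable with imitation's `RewardFn` interface fits here — in the adversarial trainer it is `reward_train.predict_processed`:

import numpy as np

def example_reward_fn(
    state: np.ndarray,       # (batch, *obs_shape)
    action: np.ndarray,      # (batch, *act_shape)
    next_state: np.ndarray,  # (batch, *obs_shape)
    done: np.ndarray,        # (batch,)
) -> np.ndarray:
    """Placeholder with the same interface as a reward net's predict_processed."""
    # A real reward network would score each transition; zeros keep the
    # sketch self-contained.
    return np.zeros(len(state), dtype=np.float32)
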


class ReplayBufferRewardWrapper(ReplayBuffer):
"""Relabel the rewards in transitions sampled from a ReplayBuffer."""

@@ -50,7 +63,9 @@ def __init__(
# DictReplayBuffer because the current RewardFn only takes in NumPy array-based
# inputs, and SAC is the only use case for ReplayBuffer relabeling. See:
# https://github.com/HumanCompatibleAI/imitation/pull/459#issuecomment-1201997194
assert replay_buffer_class is ReplayBuffer, "only ReplayBuffer is supported"
assert (
replay_buffer_class is ReplayBuffer
), f"only ReplayBuffer is supported: given {replay_buffer_class}"
assert not isinstance(observation_space, spaces.Dict)
self.replay_buffer = replay_buffer_class(
buffer_size,
@@ -80,7 +95,7 @@ def full(self, full: bool):

def sample(self, *args, **kwargs):
samples = self.replay_buffer.sample(*args, **kwargs)
rewards = self.reward_fn(**_samples_to_reward_fn_input(samples))
rewards = self.reward_fn(**_replay_samples_to_reward_fn_input(samples))
shape = samples.rewards.shape
device = samples.rewards.device
rewards_th = util.safe_to_tensor(rewards).reshape(shape).to(device)
@@ -101,3 +116,120 @@ def _get_samples(self):
"_get_samples() is intentionally not implemented."
"This method should not be called.",
)
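
For reference, a minimal construction sketch of this existing wrapper, which the new rollout-buffer wrapper below mirrors (editor's illustration; the spaces, sizes, and `dummy_reward_fn` are arbitrary placeholders):

import numpy as np
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer

from imitation.policies import replay_buffer_wrapper

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
act_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

def dummy_reward_fn(state, action, next_state, done):
    # Stand-in for a learned reward net's predict_processed.
    return np.zeros(len(state), dtype=np.float32)

wrapped = replay_buffer_wrapper.ReplayBufferRewardWrapper(
    buffer_size=1000,
    observation_space=obs_space,
    action_space=act_space,
    replay_buffer_class=ReplayBuffer,
    reward_fn=dummy_reward_fn,
    n_envs=1,
)
# Transitions are added exactly as with a plain ReplayBuffer; sample() then
# returns batches whose rewards come from dummy_reward_fn instead of the
# environment.
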


class RolloutBufferRewardWrapper(BaseBuffer):
"""Relabel the rewards in transitions sampled from a RolloutBuffer."""

def __init__(
self,
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
*,
rollout_buffer_class: Type[RolloutBuffer],
reward_fn: RewardFn,
**kwargs,
):
"""Builds RolloutBufferRewardWrapper.

Args:
buffer_size: Max number of elements in the buffer
observation_space: Observation space
action_space: Action space
rollout_buffer_class: Class of the rollout buffer.
reward_fn: Reward function for reward relabeling.
**kwargs: keyword arguments for RolloutBuffer.
"""
# Note(yawen-d): we wrap RolloutBuffer and leave out the case of
# DictRolloutBuffer because the current RewardFn only takes in NumPy
# array-based inputs, and GAIL/AIRL is the only use case for RolloutBuffer
# relabeling. See:
assert rollout_buffer_class is RolloutBuffer, "only RolloutBuffer is supported"
assert not isinstance(observation_space, spaces.Dict)
self.rollout_buffer = rollout_buffer_class(
buffer_size,
observation_space,
action_space,
**kwargs,
)
self.reward_fn = reward_fn
_base_kwargs = {k: v for k, v in kwargs.items() if k in ["device", "n_envs"]}
super().__init__(buffer_size, observation_space, action_space, **_base_kwargs)

@property
def pos(self) -> int:
return self.rollout_buffer.pos

@property
def values(self):
return self.rollout_buffer.values

@property
def observations(self):
return self.rollout_buffer.observations

@property
def actions(self):
return self.rollout_buffer.actions

@property
def log_probs(self):
return self.rollout_buffer.log_probs

@property
def advantages(self):
return self.rollout_buffer.advantages

@property
def rewards(self):
return self.rollout_buffer.rewards

@property
def returns(self):
return self.rollout_buffer.returns

@pos.setter
def pos(self, pos: int):
self.rollout_buffer.pos = pos

@property
def full(self) -> bool:
return self.rollout_buffer.full

@full.setter
def full(self, full: bool):
self.rollout_buffer.full = full

def reset(self):
self.rollout_buffer.reset()

def get(self, *args, **kwargs):
if not self.rollout_buffer.generator_ready:
input_dict = _rollout_samples_to_reward_fn_input(self.rollout_buffer)
rewards = np.zeros_like(self.rollout_buffer.rewards)
for i in range(self.buffer_size):
rewards[i] = self.reward_fn(**{k: v[i] for k, v in input_dict.items()})

self.rollout_buffer.rewards = rewards
self.rollout_buffer.compute_returns_and_advantage(
self.last_values, self.last_dones
)
ret = self.rollout_buffer.get(*args, **kwargs)
return ret

def add(self, *args, **kwargs):
self.rollout_buffer.add(*args, **kwargs)

def _get_samples(self):
raise NotImplementedError(
"_get_samples() is intentionally not implemented."
"This method should not be called.",
)

def compute_returns_and_advantage(
self, last_values: th.Tensor, dones: np.ndarray
) -> None:
self.last_values = last_values
self.last_dones = dones
self.rollout_buffer.compute_returns_and_advantage(last_values, dones)
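
A matching hedged sketch for the new wrapper, mirroring how common.py constructs it above (editor's illustration; placeholder spaces and sizes, and it is assumed the wrapped buffer exposes the fields read by `_rollout_samples_to_reward_fn_input`, including `next_observations`):

import numpy as np
from gym import spaces
from stable_baselines3.common.buffers import RolloutBuffer

from imitation.policies import replay_buffer_wrapper

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
act_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

def dummy_reward_fn(state, action, next_state, done):
    # Stand-in for reward_train.predict_processed.
    return np.zeros(len(state), dtype=np.float32)

wrapped = replay_buffer_wrapper.RolloutBufferRewardWrapper(
    buffer_size=2048,
    observation_space=obs_space,
    action_space=act_space,
    rollout_buffer_class=RolloutBuffer,
    reward_fn=dummy_reward_fn,
    gae_lambda=0.95,
    gamma=0.99,
    n_envs=1,
)
# An OnPolicyAlgorithm fills `wrapped` during collect_rollouts(); the first
# get() call then relabels the stored rewards with dummy_reward_fn and
# recomputes returns/advantages before yielding minibatches.
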