🚀 [RofuncRL] RofuncDTrans nearly finish
Skylark0924 committed Aug 27, 2023
1 parent 95b4922 commit 2d11808
Showing 4 changed files with 122 additions and 212 deletions.
237 changes: 95 additions & 142 deletions rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
@@ -12,27 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
import torch.nn as nn
import collections
from typing import Union, Tuple, Optional

import gym
import gymnasium
import numpy as np
import os
import torch
from omegaconf import DictConfig
from typing import Union, Tuple, Optional

import rofunc as rf
from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
from rofunc.learning.RofuncRL.processors.schedulers import KLAdaptiveRL
from rofunc.learning.RofuncRL.processors.standard_scaler import RunningStandardScaler
from rofunc.learning.RofuncRL.processors.standard_scaler import empty_preprocessor
from rofunc.learning.RofuncRL.state_encoders import encoder_map, EmptyEncoder
from rofunc.learning.RofuncRL.utils.device_utils import to_device
from rofunc.learning.RofuncRL.utils.memory import Memory
from rofunc.learning.RofuncRL.models.actor_models import ActorDTrans
from rofunc.learning.RofuncRL.utils.memory import Memory


class DTransAgent(BaseAgent):
@@ -62,140 +53,102 @@ def __init__(self,

super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)

Removed:

        self.policy = ActorDTrans()


def evaluate_episode(
        env,
        state_dim,
        act_dim,
        model,
        max_ep_len=1000,
        device='cuda',
        target_return=None,
        mode='normal',
        state_mean=0.,
        state_std=1.,
):
    model.eval()
    model.to(device=device)

    state_mean = torch.from_numpy(state_mean).to(device=device)
    state_std = torch.from_numpy(state_std).to(device=device)

    state = env.reset()

    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
    target_return = torch.tensor(target_return, device=device, dtype=torch.float32)
    sim_states = []

    episode_return, episode_length = 0, 0
    for t in range(max_ep_len):

        # add padding
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = model.act(
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return=target_return,
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        episode_return += reward
        episode_length += 1

        if done:
            break

    return episode_return, episode_length


def evaluate_episode_rtg(
        env,
        state_dim,
        act_dim,
        model,
        max_ep_len=1000,
        scale=1000.,
        state_mean=0.,
        state_std=1.,
        device='cuda',
        target_return=None,
        mode='normal',
):
    model.eval()
    model.to(device=device)

    state_mean = torch.from_numpy(state_mean).to(device=device)
    state_std = torch.from_numpy(state_std).to(device=device)

    state = env.reset()
    if mode == 'noise':
        state = state + np.random.normal(0, 0.1, size=state.shape)

    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    ep_return = target_return
    target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    sim_states = []

    episode_return, episode_length = 0, 0
    for t in range(max_ep_len):

        # add padding
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = model.act(
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return.to(dtype=torch.float32),
            timesteps.to(dtype=torch.long),
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        if mode != 'delayed':
            pred_return = target_return[0, -1] - (reward / scale)
        else:
            pred_return = target_return[0, -1]
        target_return = torch.cat(
            [target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat(
            [timesteps,
             torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break

    return episode_return, episode_length

Added:

        self.dtrans = ActorDTrans(cfg.Model, observation_space, action_space, device)
        self.models = {"dtrans": self.dtrans}

        # checkpoint models
        self.checkpoint_modules["dtrans"] = self.dtrans
        self.rofunc_logger.module(f"DTrans model: {self.dtrans}")

        self.track_losses = collections.deque(maxlen=100)
        self.tracking_data = collections.defaultdict(list)

        '''Get hyper-parameters from config'''
        self._td_lambda = self.cfg.Agent.td_lambda
        self._lr = self.cfg.Agent.lr
        self._adam_eps = self.cfg.Agent.adam_eps
        self._weight_decay = self.cfg.Agent.weight_decay
        self._max_length = self.cfg.Agent.max_length

        self._set_up()

    def _set_up(self):
        """
        Set up optimizer, learning rate scheduler and state/value preprocessors
        """
        self.optimizer = torch.optim.AdamW(self.dtrans.parameters(), lr=self._lr, eps=self._adam_eps,
                                           weight_decay=self._weight_decay)
        if self._lr_scheduler is not None:
            self.scheduler = self._lr_scheduler(self.optimizer, **self._lr_scheduler_kwargs)
        self.checkpoint_modules["optimizer_policy"] = self.optimizer

        self.loss_fn = lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2)

        # set up preprocessors
        super()._set_up()

    def act(self, states, actions, rewards, returns_to_go, timesteps):
        # we don't care about the past rewards in this model
        states = states.reshape(1, -1, self.dtrans.state_dim)
        actions = actions.reshape(1, -1, self.dtrans.action_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self._max_length is not None:
            states = states[:, -self._max_length:]
            actions = actions[:, -self._max_length:]
            returns_to_go = returns_to_go[:, -self._max_length:]
            timesteps = timesteps[:, -self._max_length:]

            # pad all tokens to sequence length
            attention_mask = torch.cat([torch.zeros(self._max_length - states.shape[1]), torch.ones(states.shape[1])])
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
            states = torch.cat(
                [torch.zeros((states.shape[0], self._max_length - states.shape[1], self.dtrans.state_dim),
                             device=states.device), states], dim=1).to(dtype=torch.float32)
            actions = torch.cat(
                [torch.zeros((actions.shape[0], self._max_length - actions.shape[1], self.dtrans.action_dim),
                             device=actions.device), actions], dim=1).to(dtype=torch.float32)
            returns_to_go = torch.cat(
                [torch.zeros((returns_to_go.shape[0], self._max_length - returns_to_go.shape[1], 1),
                             device=returns_to_go.device), returns_to_go], dim=1).to(dtype=torch.float32)
            timesteps = torch.cat([torch.zeros((timesteps.shape[0], self._max_length - timesteps.shape[1]),
                                               device=timesteps.device), timesteps], dim=1).to(dtype=torch.long)
        else:
            attention_mask = None

        _, action_preds, return_preds = self.dtrans(states, actions, None, returns_to_go, timesteps,
                                                    attention_mask=attention_mask)

        return action_preds[0, -1]

    def update_net(self):
        states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
        action_target = torch.clone(actions)

        state_preds, action_preds, reward_preds = self.dtrans.forward(
            states, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]
        action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]

        loss = self.loss_fn(None, action_preds, None,
                            None, action_target, None)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.dtrans.parameters(), .25)
        self.optimizer.step()

        with torch.no_grad():
            self.tracking_data['training/action_error'].append(
                torch.mean((action_preds - action_target) ** 2).item())

        # update learning rate
        if self._lr_scheduler is not None:
            self.scheduler.step()

        # record data
        self.track_data("Loss", loss.item())
37 changes: 2 additions & 35 deletions rofunc/learning/RofuncRL/models/actor_models.py
@@ -251,17 +251,17 @@ def __init__(self, cfg: DictConfig,
self.log_std = nn.Parameter(torch.full((self.action_dim,), fill_value=-2.9), requires_grad=False)


class ActorDTrans:
class ActorDTrans(nn.Module):
def __init__(self, cfg: DictConfig,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
state_encoder: Optional[nn.Module] = EmptyEncoder()):
super().__init__()

self.cfg = cfg
self.action_dim = get_space_dim(action_space)
self.gpt2_hidden_size = cfg.actor.gpt2_hidden_size
self.max_ep_len = cfg.actor.max_ep_len
self.max_length = cfg.actor.max_length

# state encoder
self.state_encoder = state_encoder
@@ -336,39 +336,6 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_

return state_preds, action_preds, return_preds

Removed:

    def act(self, states, actions, rewards, returns_to_go, timesteps):
        # we don't care about the past rewards in this model
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.action_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            states = states[:, -self.max_length:]
            actions = actions[:, -self.max_length:]
            returns_to_go = returns_to_go[:, -self.max_length:]
            timesteps = timesteps[:, -self.max_length:]

            # pad all tokens to sequence length
            attention_mask = torch.cat([torch.zeros(self.max_length - states.shape[1]), torch.ones(states.shape[1])])
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
            states = torch.cat([torch.zeros((states.shape[0], self.max_length - states.shape[1], self.state_dim),
                                            device=states.device), states], dim=1).to(dtype=torch.float32)
            actions = torch.cat([torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.action_dim),
                                             device=actions.device), actions], dim=1).to(dtype=torch.float32)
            returns_to_go = torch.cat(
                [torch.zeros((returns_to_go.shape[0], self.max_length - returns_to_go.shape[1], 1),
                             device=returns_to_go.device), returns_to_go], dim=1).to(dtype=torch.float32)
            timesteps = torch.cat([torch.zeros((timesteps.shape[0], self.max_length - timesteps.shape[1]),
                                               device=timesteps.device), timesteps], dim=1).to(dtype=torch.long)
        else:
            attention_mask = None

        _, action_preds, return_preds = self.forward(
            states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask)

        return action_preds[0, -1]


if __name__ == '__main__':
from omegaconf import DictConfig
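Both the act method removed here and the new DTransAgent.act shown above left-pad the context window to max_length and mark the padded positions with zeros in the attention mask. Below is a minimal, self-contained sketch of that padding convention; the max_length, state_dim, action_dim and sequence-length values are made up for illustration.

import torch

max_length, state_dim, action_dim = 5, 3, 2    # hypothetical sizes
T = 2                                          # only two real timesteps collected so far

states = torch.randn(1, T, state_dim)
actions = torch.randn(1, T, action_dim)
returns_to_go = torch.randn(1, T, 1)
timesteps = torch.arange(T).reshape(1, T)

pad = max_length - T
# zeros on the left mark padding, ones on the right mark the real tokens
attention_mask = torch.cat([torch.zeros(pad), torch.ones(T)]).to(dtype=torch.long).reshape(1, -1)
states = torch.cat([torch.zeros(1, pad, state_dim), states], dim=1)
actions = torch.cat([torch.zeros(1, pad, action_dim), actions], dim=1)
returns_to_go = torch.cat([torch.zeros(1, pad, 1), returns_to_go], dim=1)
timesteps = torch.cat([torch.zeros(1, pad, dtype=torch.long), timesteps], dim=1)

print(attention_mask)   # tensor([[0, 0, 0, 1, 1]])
print(states.shape)     # torch.Size([1, 5, 3])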
2 changes: 1 addition & 1 deletion rofunc/learning/RofuncRL/trainers/base_trainer.py
@@ -148,7 +148,7 @@ def get_action(self, states):

def train(self):
"""
Main training loop.
Main training loop. \n
- Reset the environment
- For each step:
- Pre-interaction
58 changes: 24 additions & 34 deletions rofunc/learning/RofuncRL/trainers/dtrans_trainer.py
@@ -12,37 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

Removed:

import torch

from .base_trainer import Trainer


class SequenceTrainer(Trainer):

    def train_step(self):
        states, actions, rewards, dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
        action_target = torch.clone(actions)

        state_preds, action_preds, reward_preds = self.model.forward(
            states, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]
        action_target = action_target.reshape(-1, act_dim)[attention_mask.reshape(-1) > 0]

        loss = self.loss_fn(
            None, action_preds, None,
            None, action_target, None,
        )

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean(
                (action_preds - action_target) ** 2).detach().cpu().item()

        return loss.detach().cpu().item()
Added:

import tqdm

from rofunc.learning.RofuncRL.agents.offline.dtrans_agent import DTransAgent
from rofunc.learning.RofuncRL.trainers.base_trainer import BaseTrainer


class DTransTrainer(BaseTrainer):
    def __init__(self, cfg, env, device, env_name):
        super().__init__(cfg, env, device, env_name)
        self.agent = DTransAgent(cfg, self.env.observation_space, self.env.action_space, self.memory,
                                 device, self.exp_dir, self.rofunc_logger)
        self.setup_wandb()

    def train(self):
        """
        Main training loop.
        """
        with tqdm.trange(self.maximum_steps, ncols=80, colour='green') as self.t_bar:
            for _ in self.t_bar:
                self.agent.update_net()

        # close the logger
        self.writer.close()
        self.rofunc_logger.info('Training complete.')
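For context, here is a rough sketch of how the new DTransTrainer could be driven end to end. Only the constructor signature and train() come from this commit; the config path, environment id and everything else below are hypothetical placeholders, and the offline trajectory batching that update_net expects (get_batch) still has to be wired up elsewhere.

import gym
import torch
from omegaconf import OmegaConf

from rofunc.learning.RofuncRL.trainers.dtrans_trainer import DTransTrainer

cfg = OmegaConf.load("path/to/dtrans_config.yaml")   # hypothetical config file
env = gym.make("Hopper-v3")                          # hypothetical environment choice
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer = DTransTrainer(cfg, env, device, env_name="Hopper-v3")
trainer.train()   # runs agent.update_net() for maximum_steps, then closes the logger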
