🚀 [RofuncRL] Update RofuncDTrans
Skylark0924 committed Aug 29, 2023
1 parent a5d0a8f commit bba26b6
Showing 8 changed files with 114 additions and 102 deletions.
6 changes: 3 additions & 3 deletions examples/learning_ml/example_felt.py
@@ -57,11 +57,11 @@ def get_orientation(df):
return demos_x, demos_taxels_pressure


demos_x, demos_taxels_pressure = data_process('../data/felt/wipe_spiral')
demos_x, demos_taxels_pressure = data_process('../data/felt/wipe_raster')

# --- TP-GMM ---
demos_x = [demo_x[:, :3] for demo_x in demos_x]
demos_x = [demos_x[0]]
demos_x = [demo_x[:500, :3] for demo_x in demos_x]
# demos_x = [demos_x[0]]
# Define the task parameters
start_xdx = [demos_x[i][0] for i in range(len(demos_x))] # TODO: change to xdx
end_xdx = [demos_x[i][-1] for i in range(len(demos_x))]
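Note on the change above: every demonstration is now kept (rather than only the first), but each is truncated to its first 500 samples and its three position columns before fitting the TP-GMM. A minimal sketch of the slicing, assuming each demo is a (T, D) NumPy array whose first three columns are xyz positions (the pose dimension 7 below is hypothetical):

import numpy as np

demo = np.random.rand(1200, 7)                 # hypothetical demo: 1200 samples, 7-dim pose
trimmed = demo[:500, :3]                       # keep first 500 samples, xyz only -> shape (500, 3)
start_xdx, end_xdx = trimmed[0], trimmed[-1]   # task parameters, as in the lines above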
2 changes: 1 addition & 1 deletion examples/learning_rl/example_D4RL_RofuncRL.py
@@ -86,7 +86,7 @@ def inference(custom_args):

parser = argparse.ArgumentParser()
# Available tasks: Hopper, HalfCheetah, Walker2d, Reacher2d
parser.add_argument("--task", type=str, default="Hopper")
parser.add_argument("--task", type=str, default="Walker2d")
parser.add_argument("--agent", type=str, default="dtrans") # dtrans
parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
1 change: 1 addition & 0 deletions rofunc/config/learning/rl/task/HalfCheetah.yaml
@@ -0,0 +1 @@
name: HalfCheetah
1 change: 1 addition & 0 deletions rofunc/config/learning/rl/task/Reacher2d.yaml
@@ -0,0 +1 @@
name: Reacher2d
1 change: 1 addition & 0 deletions rofunc/config/learning/rl/task/Walker2d.yaml
@@ -0,0 +1 @@
name: Walker2d
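Each of the three new task files only registers the environment name. A minimal sketch of reading one of them with OmegaConf; how Rofunc's own config loader composes these files is not shown in this commit, so the path and access pattern below are illustrative only:

from omegaconf import OmegaConf

task_cfg = OmegaConf.load("rofunc/config/learning/rl/task/Walker2d.yaml")
print(task_cfg.name)   # -> "Walker2d"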
4 changes: 2 additions & 2 deletions rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
@@ -22,7 +22,7 @@

import rofunc as rf
from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
from rofunc.learning.RofuncRL.models.actor_models import ActorDTrans
from rofunc.learning.RofuncRL.models.misc_models import DTrans


class DTransAgent(BaseAgent):
@@ -50,7 +50,7 @@ def __init__(self,

super().__init__(cfg, observation_space, action_space, None, device, experiment_dir, rofunc_logger)

self.dtrans = ActorDTrans(cfg.Model, observation_space, action_space, self.se).to(self.device)
self.dtrans = DTrans(cfg.Model, observation_space, action_space, self.se).to(self.device)
self.models = {"dtrans": self.dtrans}

# checkpoint models
94 changes: 0 additions & 94 deletions rofunc/learning/RofuncRL/models/actor_models.py
@@ -19,7 +19,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from omegaconf import DictConfig
from torch import Tensor
from torch.distributions import Beta, Normal
@@ -251,99 +250,6 @@ def __init__(self, cfg: DictConfig,
self.log_std = nn.Parameter(torch.full((self.action_dim,), fill_value=-2.9), requires_grad=False)


class ActorDTrans(nn.Module):
def __init__(self, cfg: DictConfig,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
state_encoder: Optional[nn.Module] = EmptyEncoder()):
super().__init__()

self.cfg = cfg
self.action_dim = get_space_dim(action_space)
self.n_embd = cfg.actor.n_embd
self.max_ep_len = cfg.actor.max_episode_steps

# state encoder
self.state_encoder = state_encoder
if isinstance(self.state_encoder, EmptyEncoder):
self.state_dim = get_space_dim(observation_space)
else:
self.state_dim = self.state_encoder.output_dim

gpt_config = transformers.GPT2Config(
vocab_size=1, # doesn't matter -- we don't use the vocab
n_embd=self.n_embd,
n_layer=self.cfg.actor.n_layer,
n_head=self.cfg.actor.n_head,
n_inner=self.n_embd * 4,
activation_function=self.cfg.actor.activation_function,
resid_pdrop=self.cfg.actor.dropout,
attn_pdrop=self.cfg.actor.dropout,
n_positions=1024
)

self.embed_timestep = nn.Embedding(self.max_ep_len, self.n_embd)
self.embed_return = torch.nn.Linear(1, self.n_embd)
self.embed_state = torch.nn.Linear(self.state_dim, self.n_embd)
self.embed_action = torch.nn.Linear(self.action_dim, self.n_embd)
self.embed_ln = nn.LayerNorm(self.n_embd)

self.backbone_net = transformers.GPT2Model(gpt_config)

# note: we don't predict states or returns for the paper
self.predict_state = torch.nn.Linear(self.n_embd, self.state_dim)
self.predict_action = nn.Sequential(*([nn.Linear(self.n_embd, self.action_dim)] +
([nn.Tanh()] if self.cfg.use_action_out_tanh else [])))
self.predict_return = torch.nn.Linear(self.n_embd, 1)

def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
batch_size, seq_length = states.shape[0], states.shape[1]

if attention_mask is None:
# attention mask for GPT: 1 if can be attended to, 0 if not
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)

# state encoder
states = self.state_encoder(states)

# embed each modality with a different head
state_embeddings = self.embed_state(states)
action_embeddings = self.embed_action(actions)
returns_embeddings = self.embed_return(returns_to_go)
time_embeddings = self.embed_timestep(timesteps)

# time embeddings are treated similar to positional embeddings
state_embeddings = state_embeddings + time_embeddings
action_embeddings = action_embeddings + time_embeddings
returns_embeddings = returns_embeddings + time_embeddings

# this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
# which works nice in an autoregressive sense since states predict actions
stacked_inputs = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
stacked_inputs = self.embed_ln(stacked_inputs)

# to make the attention mask fit the stacked inputs, have to stack it as well
stacked_attention_mask = torch.stack((attention_mask, attention_mask, attention_mask), dim=1
).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)

# we feed in the input embeddings (not word indices as in NLP) to the model
transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs,
attention_mask=stacked_attention_mask)
x = transformer_outputs['last_hidden_state']

# reshape x so that the second dimension corresponds to the original
# returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
x = x.reshape(batch_size, seq_length, 3, self.n_embd).permute(0, 2, 1, 3)

# get predictions
return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
action_preds = self.predict_action(x[:, 1]) # predict next action given state

return state_preds, action_preds, return_preds


if __name__ == '__main__':
from omegaconf import DictConfig

107 changes: 105 additions & 2 deletions rofunc/learning/RofuncRL/models/misc_models.py
@@ -12,12 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Tuple, Optional

import gym
import gymnasium
import torch
import torch.nn as nn
import transformers
from omegaconf import DictConfig
from torch import Tensor

from .base_models import BaseMLP
from .utils import init_layers
from rofunc.config.utils import omegaconf_to_dict
from rofunc.learning.RofuncRL.models.base_models import BaseMLP
from rofunc.learning.RofuncRL.models.utils import init_layers, get_space_dim
from rofunc.learning.RofuncRL.state_encoders.base_encoders import EmptyEncoder


class ASEDiscEnc(BaseMLP):
@@ -44,3 +52,98 @@ def get_disc(self, x: Tensor) -> Tensor:
x = self.backbone_net(x)
x = self.disc_layer(x)
return x


class DTrans(nn.Module):
def __init__(self, cfg: DictConfig,
observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
state_encoder: Optional[nn.Module] = EmptyEncoder(),
cfg_name='actor'):
super().__init__()

self.cfg = cfg
self.cfg_dict = omegaconf_to_dict(self.cfg)
self.action_dim = get_space_dim(action_space)
self.n_embd = self.cfg_dict[cfg_name]["n_embd"]
self.max_ep_len = self.cfg_dict[cfg_name]["max_episode_steps"]

# state encoder
self.state_encoder = state_encoder
if isinstance(self.state_encoder, EmptyEncoder):
self.state_dim = get_space_dim(observation_space)
else:
self.state_dim = self.state_encoder.output_dim

gpt_config = transformers.GPT2Config(
vocab_size=1, # doesn't matter -- we don't use the vocab
n_embd=self.n_embd,
n_layer=self.cfg_dict[cfg_name]["n_layer"],
n_head=self.cfg_dict[cfg_name]["n_head"],
n_inner=self.n_embd * 4,
activation_function=self.cfg_dict[cfg_name]["activation_function"],
resid_pdrop=self.cfg_dict[cfg_name]["dropout"],
attn_pdrop=self.cfg_dict[cfg_name]["dropout"],
n_positions=1024
)

self.embed_timestep = nn.Embedding(self.max_ep_len, self.n_embd)
self.embed_return = torch.nn.Linear(1, self.n_embd)
self.embed_state = torch.nn.Linear(self.state_dim, self.n_embd)
self.embed_action = torch.nn.Linear(self.action_dim, self.n_embd)
self.embed_ln = nn.LayerNorm(self.n_embd)

self.backbone_net = transformers.GPT2Model(gpt_config)

# note: we don't predict states or returns for the paper
self.predict_state = torch.nn.Linear(self.n_embd, self.state_dim)
self.predict_action = nn.Sequential(*([nn.Linear(self.n_embd, self.action_dim)] +
([nn.Tanh()] if self.cfg.use_action_out_tanh else [])))
self.predict_return = torch.nn.Linear(self.n_embd, 1)

def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
batch_size, seq_length = states.shape[0], states.shape[1]

if attention_mask is None:
# attention mask for GPT: 1 if can be attended to, 0 if not
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)

# state encoder
states = self.state_encoder(states)

# embed each modality with a different head
state_embeddings = self.embed_state(states)
action_embeddings = self.embed_action(actions)
returns_embeddings = self.embed_return(returns_to_go)
time_embeddings = self.embed_timestep(timesteps)

# time embeddings are treated similar to positional embeddings
state_embeddings = state_embeddings + time_embeddings
action_embeddings = action_embeddings + time_embeddings
returns_embeddings = returns_embeddings + time_embeddings

# this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
# which works nice in an autoregressive sense since states predict actions
stacked_inputs = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
stacked_inputs = self.embed_ln(stacked_inputs)

# to make the attention mask fit the stacked inputs, have to stack it as well
stacked_attention_mask = torch.stack((attention_mask, attention_mask, attention_mask), dim=1
).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)

# we feed in the input embeddings (not word indices as in NLP) to the model
transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs,
attention_mask=stacked_attention_mask)
x = transformer_outputs['last_hidden_state']

# reshape x so that the second dimension corresponds to the original
# returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
x = x.reshape(batch_size, seq_length, 3, self.n_embd).permute(0, 2, 1, 3)

# get predictions
return_preds = self.predict_return(x[:, 2]) # predict next return given state and action
state_preds = self.predict_state(x[:, 2]) # predict next state given state and action
action_preds = self.predict_action(x[:, 1]) # predict next action given state

return state_preds, action_preds, return_preds
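
For orientation, a minimal usage sketch of the relocated DTrans model on random tensors. The observation/action dimensions, the config values, and the behaviour of get_space_dim and the default EmptyEncoder (assumed to return the flattened Box dimension and to pass states through unchanged, respectively) are assumptions, not part of this commit:

import gym
import torch
from omegaconf import OmegaConf
from rofunc.learning.RofuncRL.models.misc_models import DTrans

# cfg keys mirror those read in __init__; the concrete values are assumptions
cfg = OmegaConf.create({
    "use_action_out_tanh": True,
    "actor": {"n_embd": 128, "n_layer": 3, "n_head": 1, "max_episode_steps": 1000,
              "activation_function": "relu", "dropout": 0.1},
})
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(17,))   # e.g. Walker2d observations
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(6,))    # e.g. Walker2d actions

model = DTrans(cfg, obs_space, act_space)
B, T = 4, 20                                                  # 4 sub-trajectories of length 20
states = torch.randn(B, T, 17)
actions = torch.randn(B, T, 6)
rewards = torch.randn(B, T, 1)                                # unused by forward, kept for the signature
returns_to_go = torch.randn(B, T, 1)
timesteps = torch.arange(T).repeat(B, 1)                      # long tensor, must stay < max_episode_steps

state_preds, action_preds, return_preds = model(states, actions, rewards, returns_to_go, timesteps)
# expected shapes: (B, T, 17), (B, T, 6), (B, T, 1)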
