diff --git a/examples/learning_ml/example_felt.py b/examples/learning_ml/example_felt.py
index b35f4d542..684c6902a 100644
--- a/examples/learning_ml/example_felt.py
+++ b/examples/learning_ml/example_felt.py
@@ -57,11 +57,11 @@ def get_orientation(df):
     return demos_x, demos_taxels_pressure


-demos_x, demos_taxels_pressure = data_process('../data/felt/wipe_spiral')
+demos_x, demos_taxels_pressure = data_process('../data/felt/wipe_raster')

 # --- TP-GMM ---
-demos_x = [demo_x[:, :3] for demo_x in demos_x]
-demos_x = [demos_x[0]]
+demos_x = [demo_x[:500, :3] for demo_x in demos_x]
+# demos_x = [demos_x[0]]
 # Define the task parameters
 start_xdx = [demos_x[i][0] for i in range(len(demos_x))]  # TODO: change to xdx
 end_xdx = [demos_x[i][-1] for i in range(len(demos_x))]
diff --git a/examples/learning_rl/example_D4RL_RofuncRL.py b/examples/learning_rl/example_D4RL_RofuncRL.py
index 7881cf24e..26167886d 100644
--- a/examples/learning_rl/example_D4RL_RofuncRL.py
+++ b/examples/learning_rl/example_D4RL_RofuncRL.py
@@ -86,7 +86,7 @@ def inference(custom_args):

     parser = argparse.ArgumentParser()
     # Available tasks: Hopper, HalfCheetah, Walker2d, Reacher2d
-    parser.add_argument("--task", type=str, default="Hopper")
+    parser.add_argument("--task", type=str, default="Walker2d")
     parser.add_argument("--agent", type=str, default="dtrans")  # dtrans
     parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
     parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
diff --git a/rofunc/config/learning/rl/task/HalfCheetah.yaml b/rofunc/config/learning/rl/task/HalfCheetah.yaml
new file mode 100644
index 000000000..08b3bbf39
--- /dev/null
+++ b/rofunc/config/learning/rl/task/HalfCheetah.yaml
@@ -0,0 +1 @@
+name: HalfCheetah
\ No newline at end of file
diff --git a/rofunc/config/learning/rl/task/Reacher2d.yaml b/rofunc/config/learning/rl/task/Reacher2d.yaml
new file mode 100644
index 000000000..7aefacf42
--- /dev/null
+++ b/rofunc/config/learning/rl/task/Reacher2d.yaml
@@ -0,0 +1 @@
+name: Reacher2d
\ No newline at end of file
diff --git a/rofunc/config/learning/rl/task/Walker2d.yaml b/rofunc/config/learning/rl/task/Walker2d.yaml
new file mode 100644
index 000000000..f6f50d105
--- /dev/null
+++ b/rofunc/config/learning/rl/task/Walker2d.yaml
@@ -0,0 +1 @@
+name: Walker2d
\ No newline at end of file
diff --git a/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py b/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
index 4d752afbc..3d480df1e 100644
--- a/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
+++ b/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
@@ -22,7 +22,7 @@

 import rofunc as rf
 from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
-from rofunc.learning.RofuncRL.models.actor_models import ActorDTrans
+from rofunc.learning.RofuncRL.models.misc_models import DTrans


 class DTransAgent(BaseAgent):
@@ -50,7 +50,7 @@ def __init__(self,

         super().__init__(cfg, observation_space, action_space, None, device, experiment_dir, rofunc_logger)

-        self.dtrans = ActorDTrans(cfg.Model, observation_space, action_space, self.se).to(self.device)
+        self.dtrans = DTrans(cfg.Model, observation_space, action_space, self.se).to(self.device)
         self.models = {"dtrans": self.dtrans}

         # checkpoint models
diff --git a/rofunc/learning/RofuncRL/models/actor_models.py b/rofunc/learning/RofuncRL/models/actor_models.py
index 622733edc..f514eec93 100644
--- a/rofunc/learning/RofuncRL/models/actor_models.py
+++ b/rofunc/learning/RofuncRL/models/actor_models.py
@@ -19,7 +19,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import transformers
 from omegaconf import DictConfig
 from torch import Tensor
 from torch.distributions import Beta, Normal
@@ -251,99 +250,6 @@ def __init__(self, cfg: DictConfig,
         self.log_std = nn.Parameter(torch.full((self.action_dim,), fill_value=-2.9), requires_grad=False)


-class ActorDTrans(nn.Module):
-    def __init__(self, cfg: DictConfig,
-                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
-                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
-                 state_encoder: Optional[nn.Module] = EmptyEncoder()):
-        super().__init__()
-
-        self.cfg = cfg
-        self.action_dim = get_space_dim(action_space)
-        self.n_embd = cfg.actor.n_embd
-        self.max_ep_len = cfg.actor.max_episode_steps
-
-        # state encoder
-        self.state_encoder = state_encoder
-        if isinstance(self.state_encoder, EmptyEncoder):
-            self.state_dim = get_space_dim(observation_space)
-        else:
-            self.state_dim = self.state_encoder.output_dim
-
-        gpt_config = transformers.GPT2Config(
-            vocab_size=1,  # doesn't matter -- we don't use the vocab
-            n_embd=self.n_embd,
-            n_layer=self.cfg.actor.n_layer,
-            n_head=self.cfg.actor.n_head,
-            n_inner=self.n_embd * 4,
-            activation_function=self.cfg.actor.activation_function,
-            resid_pdrop=self.cfg.actor.dropout,
-            attn_pdrop=self.cfg.actor.dropout,
-            n_positions=1024
-        )
-
-        self.embed_timestep = nn.Embedding(self.max_ep_len, self.n_embd)
-        self.embed_return = torch.nn.Linear(1, self.n_embd)
-        self.embed_state = torch.nn.Linear(self.state_dim, self.n_embd)
-        self.embed_action = torch.nn.Linear(self.action_dim, self.n_embd)
-        self.embed_ln = nn.LayerNorm(self.n_embd)
-
-        self.backbone_net = transformers.GPT2Model(gpt_config)
-
-        # note: we don't predict states or returns for the paper
-        self.predict_state = torch.nn.Linear(self.n_embd, self.state_dim)
-        self.predict_action = nn.Sequential(*([nn.Linear(self.n_embd, self.action_dim)] +
-                                              ([nn.Tanh()] if self.cfg.use_action_out_tanh else [])))
-        self.predict_return = torch.nn.Linear(self.n_embd, 1)
-
-    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
-        batch_size, seq_length = states.shape[0], states.shape[1]
-
-        if attention_mask is None:
-            # attention mask for GPT: 1 if can be attended to, 0 if not
-            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
-
-        # state encoder
-        states = self.state_encoder(states)
-
-        # embed each modality with a different head
-        state_embeddings = self.embed_state(states)
-        action_embeddings = self.embed_action(actions)
-        returns_embeddings = self.embed_return(returns_to_go)
-        time_embeddings = self.embed_timestep(timesteps)
-
-        # time embeddings are treated similar to positional embeddings
-        state_embeddings = state_embeddings + time_embeddings
-        action_embeddings = action_embeddings + time_embeddings
-        returns_embeddings = returns_embeddings + time_embeddings
-
-        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
-        # which works nice in an autoregressive sense since states predict actions
-        stacked_inputs = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1
-                                     ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
-        stacked_inputs = self.embed_ln(stacked_inputs)
-
-        # to make the attention mask fit the stacked inputs, have to stack it as well
-        stacked_attention_mask = torch.stack((attention_mask, attention_mask, attention_mask), dim=1
-                                             ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)
-
-        # we feed in the input embeddings (not word indices as in NLP) to the model
-        transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs,
-                                                attention_mask=stacked_attention_mask)
-        x = transformer_outputs['last_hidden_state']
-
-        # reshape x so that the second dimension corresponds to the original
-        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
-        x = x.reshape(batch_size, seq_length, 3, self.n_embd).permute(0, 2, 1, 3)
-
-        # get predictions
-        return_preds = self.predict_return(x[:, 2])  # predict next return given state and action
-        state_preds = self.predict_state(x[:, 2])  # predict next state given state and action
-        action_preds = self.predict_action(x[:, 1])  # predict next action given state
-
-        return state_preds, action_preds, return_preds
-
-
 if __name__ == '__main__':
     from omegaconf import DictConfig

diff --git a/rofunc/learning/RofuncRL/models/misc_models.py b/rofunc/learning/RofuncRL/models/misc_models.py
index 05bc8942b..50b2da67c 100644
--- a/rofunc/learning/RofuncRL/models/misc_models.py
+++ b/rofunc/learning/RofuncRL/models/misc_models.py
@@ -12,12 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Union, Tuple, Optional
+
+import gym
+import gymnasium
+import torch
 import torch.nn as nn
+import transformers
 from omegaconf import DictConfig
 from torch import Tensor

-from .base_models import BaseMLP
-from .utils import init_layers
+from rofunc.config.utils import omegaconf_to_dict
+from rofunc.learning.RofuncRL.models.base_models import BaseMLP
+from rofunc.learning.RofuncRL.models.utils import init_layers, get_space_dim
+from rofunc.learning.RofuncRL.state_encoders.base_encoders import EmptyEncoder


 class ASEDiscEnc(BaseMLP):
@@ -44,3 +52,98 @@ def get_disc(self, x: Tensor) -> Tensor:
         x = self.backbone_net(x)
         x = self.disc_layer(x)
         return x
+
+
+class DTrans(nn.Module):
+    def __init__(self, cfg: DictConfig,
+                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
+                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
+                 state_encoder: Optional[nn.Module] = EmptyEncoder(),
+                 cfg_name='actor'):
+        super().__init__()
+
+        self.cfg = cfg
+        self.cfg_dict = omegaconf_to_dict(self.cfg)
+        self.action_dim = get_space_dim(action_space)
+        self.n_embd = self.cfg_dict[cfg_name]["n_embd"]
+        self.max_ep_len = self.cfg_dict[cfg_name]["max_episode_steps"]
+
+        # state encoder
+        self.state_encoder = state_encoder
+        if isinstance(self.state_encoder, EmptyEncoder):
+            self.state_dim = get_space_dim(observation_space)
+        else:
+            self.state_dim = self.state_encoder.output_dim
+
+        gpt_config = transformers.GPT2Config(
+            vocab_size=1,  # doesn't matter -- we don't use the vocab
+            n_embd=self.n_embd,
+            n_layer=self.cfg_dict[cfg_name]["n_layer"],
+            n_head=self.cfg_dict[cfg_name]["n_head"],
+            n_inner=self.n_embd * 4,
+            activation_function=self.cfg_dict[cfg_name]["activation_function"],
+            resid_pdrop=self.cfg_dict[cfg_name]["dropout"],
+            attn_pdrop=self.cfg_dict[cfg_name]["dropout"],
+            n_positions=1024
+        )
+
+        self.embed_timestep = nn.Embedding(self.max_ep_len, self.n_embd)
+        self.embed_return = torch.nn.Linear(1, self.n_embd)
+        self.embed_state = torch.nn.Linear(self.state_dim, self.n_embd)
+        self.embed_action = torch.nn.Linear(self.action_dim, self.n_embd)
+        self.embed_ln = nn.LayerNorm(self.n_embd)
+
+        self.backbone_net = transformers.GPT2Model(gpt_config)
+
+        # note: we don't predict states or returns for the paper
+        self.predict_state = torch.nn.Linear(self.n_embd, self.state_dim)
+        self.predict_action = nn.Sequential(*([nn.Linear(self.n_embd, self.action_dim)] +
+                                              ([nn.Tanh()] if self.cfg.use_action_out_tanh else [])))
+        self.predict_return = torch.nn.Linear(self.n_embd, 1)
+
+    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
+        batch_size, seq_length = states.shape[0], states.shape[1]
+
+        if attention_mask is None:
+            # attention mask for GPT: 1 if can be attended to, 0 if not
+            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
+
+        # state encoder
+        states = self.state_encoder(states)
+
+        # embed each modality with a different head
+        state_embeddings = self.embed_state(states)
+        action_embeddings = self.embed_action(actions)
+        returns_embeddings = self.embed_return(returns_to_go)
+        time_embeddings = self.embed_timestep(timesteps)
+
+        # time embeddings are treated similar to positional embeddings
+        state_embeddings = state_embeddings + time_embeddings
+        action_embeddings = action_embeddings + time_embeddings
+        returns_embeddings = returns_embeddings + time_embeddings
+
+        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
+        # which works nice in an autoregressive sense since states predict actions
+        stacked_inputs = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1
+                                     ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
+        stacked_inputs = self.embed_ln(stacked_inputs)
+
+        # to make the attention mask fit the stacked inputs, have to stack it as well
+        stacked_attention_mask = torch.stack((attention_mask, attention_mask, attention_mask), dim=1
+                                             ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)
+
+        # we feed in the input embeddings (not word indices as in NLP) to the model
+        transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs,
+                                                attention_mask=stacked_attention_mask)
+        x = transformer_outputs['last_hidden_state']
+
+        # reshape x so that the second dimension corresponds to the original
+        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
+        x = x.reshape(batch_size, seq_length, 3, self.n_embd).permute(0, 2, 1, 3)
+
+        # get predictions
+        return_preds = self.predict_return(x[:, 2])  # predict next return given state and action
+        state_preds = self.predict_state(x[:, 2])  # predict next state given state and action
+        action_preds = self.predict_action(x[:, 1])  # predict next action given state
+
+        return state_preds, action_preds, return_preds
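
The diff above only relocates the Decision Transformer backbone from actor_models.py to misc_models.py (ActorDTrans -> DTrans) and routes config access through omegaconf_to_dict. For review context, here is a minimal usage sketch of the relocated DTrans; it is not part of the diff. The config keys mirror the ones this PR reads (n_embd, n_layer, n_head, activation_function, dropout, max_episode_steps, use_action_out_tanh); the dimensions, batch layout, and passing plain ints for the spaces (permitted by the get_space_dim type hints) are illustrative assumptions, not code taken from the repository's trainer.

```python
# Usage sketch (assumptions noted above); runs on CPU with rofunc + transformers installed.
import torch
from omegaconf import OmegaConf

from rofunc.learning.RofuncRL.models.misc_models import DTrans

# Hypothetical model config matching the keys DTrans reads via cfg_dict[cfg_name][...]
cfg = OmegaConf.create({
    "use_action_out_tanh": True,
    "actor": {
        "n_embd": 128,
        "n_layer": 3,
        "n_head": 1,
        "activation_function": "relu",
        "dropout": 0.1,
        "max_episode_steps": 1000,
    },
})

state_dim, action_dim = 17, 6  # illustrative Walker2d-sized spaces
model = DTrans(cfg, state_dim, action_dim)

B, T = 4, 20  # batch of 4 sub-trajectories, context length 20
states = torch.randn(B, T, state_dim)
actions = torch.randn(B, T, action_dim)
rewards = torch.randn(B, T, 1)  # accepted by forward() but unused
returns_to_go = torch.randn(B, T, 1)
timesteps = torch.arange(T).repeat(B, 1)  # (B, T) long indices < max_episode_steps
attention_mask = torch.ones(B, T, dtype=torch.long)  # 1 = attend, 0 = padding

state_preds, action_preds, return_preds = model(
    states, actions, rewards, returns_to_go, timesteps, attention_mask)
print(action_preds.shape)  # torch.Size([4, 20, 6]): one action prediction per timestep
```

As in the upstream Decision Transformer, each timestep contributes a (return-to-go, state, action) token triple, so the GPT-2 backbone sees a sequence of length 3 * T and action_preds is read off the state tokens.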