From 7c11d6c50ce4bf5625f933d67fda6d5fba1eb708 Mon Sep 17 00:00:00 2001
From: Junjia Liu
Date: Tue, 29 Aug 2023 23:32:54 +0800
Subject: [PATCH] Update

---
 .../RofuncRL/agents/offline/dtrans_agent.py  |  3 +--
 .../learning/RofuncRL/models/misc_models.py  | 21 +++++++++++++------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py b/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
index 3d480df1e..7df6e83cc 100644
--- a/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
+++ b/rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
@@ -116,8 +116,7 @@ def act(self, states, actions, rewards, returns_to_go, timesteps):
         else:
             attention_mask = None
 
-        _, action_preds, return_preds = self.dtrans(states, actions, None, returns_to_go, timesteps,
-                                                    attention_mask=attention_mask)
+        _, action_preds, return_preds = self.dtrans(states, actions, None, returns_to_go, timesteps, attention_mask)
 
         return action_preds[0, -1]
 
diff --git a/rofunc/learning/RofuncRL/models/misc_models.py b/rofunc/learning/RofuncRL/models/misc_models.py
index 50b2da67c..2ac04bc36 100644
--- a/rofunc/learning/RofuncRL/models/misc_models.py
+++ b/rofunc/learning/RofuncRL/models/misc_models.py
@@ -102,6 +102,15 @@ def __init__(self, cfg: DictConfig,
         self.predict_return = torch.nn.Linear(self.n_embd, 1)
 
     def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
+        """
+        :param states: [B, T, C, H, W] or [B, T, Value]
+        :param actions: [B, T, Action]
+        :param rewards: [B, T, 1]
+        :param returns_to_go: [B, T, 1]
+        :param timesteps: [B, T]
+        :param attention_mask: [B, T]
+        :return:
+        """
         batch_size, seq_length = states.shape[0], states.shape[1]
 
         if attention_mask is None:
@@ -112,10 +121,10 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
             states = self.state_encoder(states)
 
         # embed each modality with a different head
-        state_embeddings = self.embed_state(states)
-        action_embeddings = self.embed_action(actions)
-        returns_embeddings = self.embed_return(returns_to_go)
-        time_embeddings = self.embed_timestep(timesteps)
+        state_embeddings = self.embed_state(states)  # [B, T, n_embd]
+        action_embeddings = self.embed_action(actions)  # [B, T, n_embd]
+        returns_embeddings = self.embed_return(returns_to_go)  # [B, T, n_embd]
+        time_embeddings = self.embed_timestep(timesteps)  # [B, T, n_embd]
 
         # time embeddings are treated similar to positional embeddings
         state_embeddings = state_embeddings + time_embeddings
@@ -124,6 +133,7 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
 
         # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
         # which works nice in an autoregressive sense since states predict actions
+        # [B, 3, T, n_embd] -> [B, T, 3, n_embd] -> [B, T*3, n_embd]
         stacked_inputs = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1
                                      ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
         stacked_inputs = self.embed_ln(stacked_inputs)
@@ -133,8 +143,7 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
                                              ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)
 
         # we feed in the input embeddings (not word indices as in NLP) to the model
-        transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs,
-                                                attention_mask=stacked_attention_mask)
+        transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs, attention_mask=stacked_attention_mask)
         x = transformer_outputs['last_hidden_state']
 
         # reshape x so that the second dimension corresponds to the original