Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylark0924 committed Aug 29, 2023
1 parent bba26b6 commit 7c11d6c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
3 changes: 1 addition & 2 deletions rofunc/learning/RofuncRL/agents/offline/dtrans_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def act(self, states, actions, rewards, returns_to_go, timesteps):
else:
attention_mask = None

_, action_preds, return_preds = self.dtrans(states, actions, None, returns_to_go, timesteps,
attention_mask=attention_mask)
_, action_preds, return_preds = self.dtrans(states, actions, None, returns_to_go, timesteps, attention_mask)

return action_preds[0, -1]

Expand Down
21 changes: 15 additions & 6 deletions rofunc/learning/RofuncRL/models/misc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ def __init__(self, cfg: DictConfig,
self.predict_return = torch.nn.Linear(self.n_embd, 1)

def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
"""
:param states: [B, T, C, H, W] or [B, T, Value]
:param actions: [B, T, Action]
:param rewards: [B, T, 1]
:param returns_to_go: [B, T, 1]
:param timesteps: [B, T]
:param attention_mask: [B, T]
:return:
"""
batch_size, seq_length = states.shape[0], states.shape[1]

if attention_mask is None:
Expand All @@ -112,10 +121,10 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
states = self.state_encoder(states)

# embed each modality with a different head
state_embeddings = self.embed_state(states)
action_embeddings = self.embed_action(actions)
returns_embeddings = self.embed_return(returns_to_go)
time_embeddings = self.embed_timestep(timesteps)
state_embeddings = self.embed_state(states) # [B, T, n_embd]
action_embeddings = self.embed_action(actions) # [B, T, n_embd]
returns_embeddings = self.embed_return(returns_to_go) # [B, T, n_embd]
time_embeddings = self.embed_timestep(timesteps) # [B, T, n_embd]

# time embeddings are treated similarly to positional embeddings
state_embeddings = state_embeddings + time_embeddings
Expand All @@ -124,6 +133,7 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_

# this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...),
# which works nicely in an autoregressive sense since states predict actions
# [B, 3, T, n_embd] -> [B, T, 3, n_embd] -> [B, T*3, n_embd]
stacked_inputs = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1
).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
stacked_inputs = self.embed_ln(stacked_inputs)
Expand All @@ -133,8 +143,7 @@ def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_
).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)

# we feed in the input embeddings (not word indices as in NLP) to the model
transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs,
attention_mask=stacked_attention_mask)
transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs, attention_mask=stacked_attention_mask)
x = transformer_outputs['last_hidden_state']

# reshape x so that the second dimension corresponds to the original
Expand Down

0 comments on commit 7c11d6c

Please sign in to comment.