
Commit

bug fix in custom inference to support alibi
Signed-off-by: Paarth Neekhara <[email protected]>
paarthneekhara committed Dec 1, 2023
1 parent 981c34b commit 746a19d
Showing 1 changed file with 11 additions and 4 deletions.
@@ -1961,9 +1961,16 @@ def custom_autoregressive_inference(self, batch, prompt_len, pred_steps=500, sid
         # curr_position_ids = batch['position_ids'][sidx:sidx+1,:prompt_len]
         dummy_position_ids = torch.arange(0, prompt_len+pred_steps, device=batch['position_ids'].device).unsqueeze(0)
         curr_position_ids = dummy_position_ids[:, :prompt_len]
 
         curr_attention_mask = None
         if batch['attention_mask'] is not None:
-            curr_attention_mask = batch['attention_mask'][sidx:sidx+1,:,:prompt_len,:prompt_len]
+            dummy_attention_mask = torch.tril(torch.ones((1, prompt_len+pred_steps+1, prompt_len+pred_steps+1))).view(
+                1, 1, prompt_len+pred_steps+1, prompt_len+pred_steps+1
+            )
+            dummy_attention_mask = dummy_attention_mask < 0.5
+            dummy_attention_mask = dummy_attention_mask.to(batch['attention_mask'].device)
+            curr_attention_mask = dummy_attention_mask[:,:,:prompt_len,:prompt_len]
+            # curr_attention_mask = batch['attention_mask'][sidx:sidx+1,:,:prompt_len,:prompt_len]
         curr_speech_mask = batch['speech_mask'][sidx:sidx+1,:prompt_len]
 
         all_preds = []
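Read outside the diff, this hunk precomputes one boolean causal mask that covers the whole prompt-plus-prediction window and slices the prompt-sized corner from it, instead of reusing batch['attention_mask'], which only covers the original sequence length. A minimal standalone sketch of that construction, using small illustrative sizes (prompt_len and pred_steps here are placeholders, not the real batch values):

import torch

# Illustrative sizes only; in the diff these come from the batch and the pred_steps argument.
prompt_len, pred_steps = 4, 3
total_len = prompt_len + pred_steps + 1

# Lower-triangular ones: position i may attend to positions <= i.
causal = torch.tril(torch.ones((1, total_len, total_len))).view(1, 1, total_len, total_len)

# Invert to a boolean mask where True means "masked out", matching the diff's convention.
dummy_attention_mask = causal < 0.5

# The first forward pass only sees the prompt, so take the prompt-sized corner.
curr_attention_mask = dummy_attention_mask[:, :, :prompt_len, :prompt_len]
print(curr_attention_mask.shape)   # torch.Size([1, 1, 4, 4])
print(curr_attention_mask[0, 0])   # True above the diagonal (masked), False on/below it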
@@ -1976,7 +1983,7 @@ def custom_autoregressive_inference(self, batch, prompt_len, pred_steps=500, sid
 
             if _t % 10 == 0:
                 print("Decoding timestep", _t)
-
+            # import ipdb; ipdb.set_trace()
             (logits, _), _, _ = self.model(
                 curr_tokens,
                 curr_position_ids,
@@ -2037,8 +2044,8 @@ def custom_autoregressive_inference(self, batch, prompt_len, pred_steps=500, sid
             curr_tokens = torch.cat([curr_tokens, all_speech_token_preds_processed[:,:,-1:]], dim=2)
             curr_position_ids = dummy_position_ids[:,:prompt_len+_t+1]
             if curr_attention_mask is not None:
-                curr_attention_mask = batch['attention_mask'][:,:,:prompt_len+_t+1,:prompt_len+_t+1]
-            curr_speech_mask = batch['speech_mask'][:,:prompt_len+_t+1]
+                curr_attention_mask = dummy_attention_mask[:,:,:prompt_len+_t+1,:prompt_len+_t+1]
+            curr_speech_mask = batch['speech_mask'][sidx:sidx+1,:prompt_len+_t+1]
 
         all_preds = torch.stack(all_preds, dim=0) # (T, B, 8)
         all_preds = all_preds.permute(1, 2, 0) # (B, 8, T)
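The last hunk applies the same idea inside the decoding loop: each predicted token extends the sequence by one, so the mask and position ids fed to the next forward pass are sliced one row/column larger from the precomputed dummy tensors rather than from the prompt-sized batch tensors. A rough sketch of that per-step slicing, again with illustrative sizes and a commented-out placeholder where the real self.model call would sit (the placeholder signature is an assumption, not the actual NeMo API):

import torch

prompt_len, pred_steps = 4, 3                      # illustrative sizes
total_len = prompt_len + pred_steps + 1

dummy_position_ids = torch.arange(0, prompt_len + pred_steps).unsqueeze(0)
dummy_attention_mask = (
    torch.tril(torch.ones((1, total_len, total_len))).view(1, 1, total_len, total_len) < 0.5
)

for _t in range(pred_steps):
    cur_len = prompt_len + _t + 1                  # sequence grows by one token per step
    curr_position_ids = dummy_position_ids[:, :cur_len]
    curr_attention_mask = dummy_attention_mask[:, :, :cur_len, :cur_len]
    # model(curr_tokens, curr_position_ids, curr_attention_mask, ...) would run here;
    # that call is a hypothetical stand-in for self.model in the diff.
    print(_t, curr_position_ids.shape, curr_attention_mask.shape)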