Added num decoder blocks in megatron export (#6331) (#6351)
* added RPE + fixed RMSNorm

* updated megatron NMT export to add decoder blocks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Signed-off-by: David Mosallanezhad <[email protected]>
Co-authored-by: David <[email protected]>
Co-authored-by: David Mosallanezhad <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people committed Apr 4, 2023
1 parent 4cf20b2 commit 2b0a7a0
Showing 1 changed file with 6 additions and 2 deletions.
@@ -123,7 +123,9 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec
             .float()
         )
 
-        zeros = torch.zeros((decoder_mems.shape[0], 6, dec_out.shape[1], decoder_mems.shape[-1])).to(self.device)
+        zeros = torch.zeros(
+            (decoder_mems.shape[0], self.decoder.num_layers, dec_out.shape[1], decoder_mems.shape[-1])
+        ).to(self.device)
 
         return torch.cat((zeros, dec_out.unsqueeze(1)), dim=1)

@@ -142,7 +144,9 @@ def input_example(self, max_batch=1, max_dim=768, seq_len=6):
         dec_attn_mask = torch.tensor([[1 for _ in range(dec_len)]]).to(self.device)
 
         # constant decoder_mems as placeholder for now
-        decoder_mems = torch.zeros([max_batch, 7, seq_len, max_dim], dtype=torch.float32).to(self.device)
+        decoder_mems = torch.zeros([max_batch, self.decoder.num_layers + 1, seq_len, max_dim], dtype=torch.float32).to(
+            self.device
+        )
 
         # input_ids, decoder_mask, encoder_mask, encoder_embeddings
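
Note on the change: the second dimension of decoder_mems holds one cache slot per decoder layer plus one extra slot, so the hard-coded 6 (in forward) and 7 (in input_example) were only valid for a six-layer decoder; the patch derives both from self.decoder.num_layers. Below is a minimal standalone sketch of the shape arithmetic, with batch/seq_len/hidden/num_layers chosen arbitrarily and tensor names borrowed from the diff. It is an illustration under those assumptions, not the NeMo module itself.

    import torch

    # hypothetical sizes for the sketch; the real module reads them from the model config
    batch, seq_len, hidden = 1, 6, 768
    num_layers = 12  # the old code hard-coded this to 6

    # placeholder cache as built in input_example(): [batch, num_layers + 1, seq_len, hidden]
    decoder_mems = torch.zeros([batch, num_layers + 1, seq_len, hidden], dtype=torch.float32)

    # inside forward(), the current step's output is appended after num_layers zero-filled
    # slots, mirroring torch.cat((zeros, dec_out.unsqueeze(1)), dim=1) in the diff
    dec_out = torch.randn(batch, seq_len, hidden)
    zeros = torch.zeros((decoder_mems.shape[0], num_layers, dec_out.shape[1], decoder_mems.shape[-1]))
    new_mems = torch.cat((zeros, dec_out.unsqueeze(1)), dim=1)

    assert new_mems.shape == (batch, num_layers + 1, seq_len, hidden)

With the constants left at 6 and 7, the concatenation above would produce a mis-sized cache for any exported model whose decoder depth differs from six.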

