Skip to content

Commit

Permalink
1. Fixed device assignemnt.
Browse files Browse the repository at this point in the history
Signed-off-by: Micha Livne <[email protected]>
  • Loading branch information
michalivne committed Mar 6, 2023
1 parent 4fe70c0 commit 0d51dfa
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1255,11 +1255,11 @@ def dummy():
else:
predicted_log_probs = torch.zeros(
(predicted_log_probs.shape[0], predicted_log_probs.shape[1]), dtype=self.autocast_dtype
).to(device)
).cuda()
predicted_tokens_dec = torch.zeros(
(predicted_tokens_dec.shape[0], predicted_tokens_dec.shape[1] + 1),
dtype=predicted_tokens_dec.dtype,
).to(device)
).cuda()

if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# Broadcast from the last pipeline stage to all other model-parallel ranks.
Expand Down

0 comments on commit 0d51dfa

Please sign in to comment.