Skip to content

Commit

Permalink
Proposed WAR for gpt3 eval hang with PP (NVIDIA#7927)
Browse files Browse the repository at this point in the history
Signed-off-by: yaoyu-33 <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
  • Loading branch information
yaoyu-33 and ericharper authored Dec 1, 2023
1 parent 44c7928 commit 93b2c7f
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,12 +507,8 @@ def synced_generate(

if compute_logprob:
precision = model._trainer.precision
if precision in [16, "16"]:
dtype = torch.float16
elif precision in ['bf16', 'bf16-mixed']:
dtype = torch.bfloat16
else:
dtype = torch.float32
dtype = torch.float32

output_logits = torch.empty(
tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda")
)
Expand Down

0 comments on commit 93b2c7f

Please sign in to comment.