
Commit 9353dda

Improve text generation quality for bf16 models when sampling

1 parent: a51475f

File tree

  • optimum/habana/transformers/generation

1 file changed: +2, -3

optimum/habana/transformers/generation/utils.py (+2, -3)
@@ -2460,7 +2460,7 @@ def _sample(
             if token_idx is not None and outputs.logits.shape[-2] > 1:
                 # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
                 if self.config.is_encoder_decoder:
-                    next_token_logits = outputs.logits[:, token_idx - 1, :].float()
+                    next_token_logits = outputs.logits[:, token_idx - 1, :]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
                 else:
                     if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2474,8 +2474,7 @@ def _sample(
                     next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
                     next_token_scores = logits_processor(input_ids, next_token_logits)
             else:
-                # .float() is needed to retain precision for later logits manipulations
-                next_token_logits = outputs.logits[:, -1, :].float()
+                next_token_logits = outputs.logits[:, -1, :]
                 if token_idx is not None and self.config.is_encoder_decoder:
                     # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
                     next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
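For illustration, here is a minimal standalone sketch (not part of the commit; the tensor shapes and the `_old`/`_new` variable names are made-up placeholders) of what the change does to the sampling path for a bf16 model: before this commit the selected logits were upcast to float32 before the logits processors ran, while after it they keep the model's dtype end to end.

    # Minimal sketch, assuming a bf16 model; shapes are arbitrary placeholders.
    import torch

    logits = torch.randn(2, 8, 32000, dtype=torch.bfloat16)  # [batch, seq, vocab]

    # Before this commit: last-position logits were upcast to float32.
    next_token_logits_old = logits[:, -1, :].float()  # dtype: torch.float32

    # After this commit: logits stay in the model dtype (bf16 here), so the
    # logits processors and softmax also run in bf16.
    next_token_logits_new = logits[:, -1, :]  # dtype: torch.bfloat16

    probs = torch.nn.functional.softmax(next_token_logits_new, dim=-1)
    # torch.multinomial may not accept bf16 on every backend, so cast the
    # probabilities for portability before sampling.
    next_tokens = torch.multinomial(probs.float(), num_samples=1).squeeze(1)

Per the commit message, keeping the logits in bf16 rather than upcasting improves the quality of sampled text for bf16 models.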
