@@ -2460,7 +2460,7 @@ def _sample(
         if token_idx is not None and outputs.logits.shape[-2] > 1:
             # case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
             if self.config.is_encoder_decoder:
-                next_token_logits = outputs.logits[:, token_idx - 1, :].float()
+                next_token_logits = outputs.logits[:, token_idx - 1, :]
                 next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
             else:
                 if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2474,8 +2474,7 @@ def _sample(
                 next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
                 next_token_scores = logits_processor(input_ids, next_token_logits)
         else:
-            # .float() is needed to retain precision for later logits manipulations
-            next_token_logits = outputs.logits[:, -1, :].float()
+            next_token_logits = outputs.logits[:, -1, :]
             if token_idx is not None and self.config.is_encoder_decoder:
                 # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
                 next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
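
For context, a minimal sketch of the trade-off behind dropping the `.float()` upcast. The shapes and temperature value below are hypothetical stand-ins for `outputs.logits` inside `_sample`: keeping `next_token_logits` in the model dtype avoids materializing an extra fp32 copy of a `[batch_size, vocab_size]` tensor on every decoding step, at the cost of running later logits manipulations in reduced precision.

```python
import torch

# Hypothetical shapes; in _sample these come from outputs.logits.
batch_size, vocab_size = 2, 32000
logits_bf16 = torch.randn(batch_size, vocab_size, dtype=torch.bfloat16)

# What the removed `.float()` did: upcast before logits processing so
# manipulations (e.g. temperature scaling, softmax) run in fp32.
scores_fp32 = torch.softmax(logits_bf16.float() / 0.7, dim=-1)

# What the patched code does: keep the model dtype end to end,
# skipping the per-step fp32 copy of the logits tensor.
scores_bf16 = torch.softmax(logits_bf16 / 0.7, dim=-1)

# The two differ only by a small numerical delta per token probability.
print((scores_fp32 - scores_bf16.float()).abs().max())
```

Whether that delta is acceptable depends on the logits processors in use; this sketch only illustrates the precision/memory trade-off the removed comment referred to.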