From 3ea61557dad38570c04edc5342bd260a693cbee0 Mon Sep 17 00:00:00 2001
From: LSinev
Date: Thu, 14 Jan 2021 19:48:20 +0300
Subject: [PATCH] make RepetitionPenaltyLogitsProcessor faster

---
 src/transformers/generation_logits_process.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation_logits_process.py b/src/transformers/generation_logits_process.py
index 166cd5aa8cc2..a027eacbde45 100644
--- a/src/transformers/generation_logits_process.py
+++ b/src/transformers/generation_logits_process.py
@@ -155,13 +155,12 @@ def __init__(self, penalty: float):
         self.penalty = penalty
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        ranges = torch.arange(scores.shape[0])
-        score = scores[ranges[:, None], input_ids]
+        score = torch.gather(scores, 1, input_ids)
 
         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
-        scores[ranges[:, None], input_ids] = score
+        scores.scatter_(1, input_ids, score)
         return scores
 
 
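Note: below is a minimal standalone sketch of the gather/scatter technique the
patch switches to, for readers who want to trace the shapes. The tensor sizes,
values, and the penalty setting are illustrative only; they are not taken from
the transformers code or its tests.

import torch

penalty = 1.2  # > 1.0 discourages repetition, per the processor's convention

# Hypothetical (batch_size=2, vocab_size=5) next-token scores and the
# (batch_size=2, seq_len=3) ids generated so far.
scores = torch.tensor([[0.5, -1.0, 2.0, 0.1, -0.3],
                       [1.5, 0.2, -0.7, 0.9, 0.0]])
input_ids = torch.tensor([[0, 2, 2],
                          [1, 3, 4]])

# Pull out the score of every previously generated token, row by row,
# without materializing an explicit batch-index tensor.
score = torch.gather(scores, 1, input_ids)

# A negative score is multiplied and a positive one divided, so the penalty
# always lowers the repeated token's probability.
score = torch.where(score < 0, score * penalty, score / penalty)

# Write the penalized scores back into the full vocabulary grid in place.
scores.scatter_(1, input_ids, score)
print(scores)

Compared with the removed advanced-indexing version, gather/scatter_ avoids
building the `ranges = torch.arange(...)` helper tensor and the broadcasted
index on every call, which is presumably where the speedup named in the
subject line comes from.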