
Conversation

@LSinev
Contributor

@LSinev LSinev commented Jan 14, 2021

What does this PR do?

Speeds up RepetitionPenaltyLogitsProcessor using the batched torch gather/scatter functions. Tested on PyTorch 1.4.0.
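For reference, the implementation being replaced iterates over each row of the batch in Python, roughly like the sketch below (a simplified reconstruction, not the exact library source):

```python
import torch


def repetition_penalty_loop(input_ids: torch.LongTensor, scores: torch.FloatTensor, penalty: float) -> torch.FloatTensor:
    # Per-row Python loop: for every sequence in the batch, penalize each
    # previously generated token id individually. This is the slow path the
    # PR replaces with a single gather/scatter over the whole batch.
    for i in range(scores.shape[0]):
        for previous_token in set(input_ids[i].tolist()):
            # If the logit is negative, multiplying by penalty (> 1) makes it
            # more negative; if positive, dividing shrinks it. Either way the
            # repeated token becomes less likely.
            if scores[i, previous_token] < 0:
                scores[i, previous_token] *= penalty
            else:
                scores[i, previous_token] /= penalty
    return scores
```

The `set()` call means each repeated token is penalized once per row, which is also what the gather/scatter version produces, since duplicated indices scatter identical penalized values.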
Here's a minimal example to reproduce the slow behavior (and test speed of improvements):

import torch
from transformers import RepetitionPenaltyLogitsProcessor, LogitsProcessor
import timeit
import sys


class RepetitionPenaltyLogitsProcessorNew(LogitsProcessor):
    r"""
    :class:`transformers.LogitsProcessor` enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (:obj:`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)  # changed here

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores.scatter_(1, input_ids, score)  # changed here
        return scores


input_ids = torch.randint(0, 10000, (256, 256))
scores = torch.randn(256, 10000)

rep_proc = RepetitionPenaltyLogitsProcessor(1.3)

rep_proc_new = RepetitionPenaltyLogitsProcessorNew(1.3)

# Both implementations modify `scores` in place and return it, so the comparison
# must run on independent copies rather than on the same tensor object.
assert torch.eq(rep_proc(input_ids, scores.clone()), rep_proc_new(input_ids, scores.clone())).all().item(), "Should be equal"

print("Python version:", sys.version)
print("Pytorch version:", torch.__version__, "\n")

print(f"Existing rep_proc impl time for 100 iterations on CPU = {timeit.timeit(lambda: rep_proc(input_ids, scores), number=100)}")
print(f"Proposed rep_proc impl time for 100 iterations on CPU = {timeit.timeit(lambda: rep_proc_new(input_ids, scores), number=100)}\n")

if torch.cuda.is_available():
    input_ids = input_ids.cuda()
    scores = scores.cuda()

    print(f"Existing rep_proc impl time for 100 iterations on GPU = {timeit.timeit(lambda: rep_proc(input_ids, scores), number=100)}")
    print(f"Proposed rep_proc impl time for 100 iterations on GPU = {timeit.timeit(lambda: rep_proc_new(input_ids, scores), number=100)}")

Timings reported:

Python version: 3.7.9 (default, Aug 31 2020, 12:42:55) 
[GCC 7.3.0]
Pytorch version: 1.4.0 

Existing rep_proc impl time for 100 iterations on CPU = 0.0807734300001357
Proposed rep_proc impl time for 100 iterations on CPU = 0.044223628000054305

Existing rep_proc impl time for 100 iterations on GPU = 0.017542457000217837
Proposed rep_proc impl time for 100 iterations on GPU = 0.00720681400025569
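The gather/scatter trick benchmarked above can be seen on a toy tensor (hypothetical values: one row of logits, two previously generated token ids, penalty 2.0):

```python
import torch

# Hypothetical toy inputs, chosen so the arithmetic is easy to follow.
scores = torch.tensor([[1.0, -2.0, 3.0]])
input_ids = torch.tensor([[0, 2]])
penalty = 2.0

# Pull out the logits of every previously seen token in one batched call...
score = torch.gather(scores, 1, input_ids)  # tensor([[1., 3.]])

# ...penalize them (multiply negatives to push them further down,
# divide positives to shrink them)...
score = torch.where(score < 0, score * penalty, score / penalty)

# ...and write the penalized values back in place.
scores.scatter_(1, input_ids, score)  # scores is now tensor([[0.5, -2.0, 1.5]])
```

All three steps are single vectorized kernels over the whole batch, which is why the per-row Python loop they replace is so much slower, on GPU the loop additionally forces a host-device sync per element access.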

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@LysandreJik, @patrickvonplaten

@LSinev LSinev mentioned this pull request Jan 16, 2021
Contributor

@patrickvonplaten patrickvonplaten left a comment


Great! Tests are passing and the code looks good -> good to merge IMO.

@patrickvonplaten patrickvonplaten merged commit a98173c into huggingface:master Jan 20, 2021
@LSinev LSinev deleted the feature/speedup_rep_penalty branch May 1, 2021 23:24
