From 123238eb42604c9f27e0038dca09b3bfe6c1acd8 Mon Sep 17 00:00:00 2001
From: behnam
Date: Wed, 23 Sep 2020 17:04:37 -0700
Subject: [PATCH] Add token_type_ids to prepare_inputs_for_generation for gpt/gpt2

---
 src/transformers/modeling_gpt2.py   |  4 ++++
 src/transformers/modeling_openai.py | 15 +++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index f1671c69cb23..219f5eb71c92 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -680,11 +680,15 @@ def get_output_embeddings(self):
 
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
+        token_type_ids = kwargs.get("token_type_ids")
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
 
         return {
             "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
         }
diff --git a/src/transformers/modeling_openai.py b/src/transformers/modeling_openai.py
index a3029c6ca256..e56c0d885a07 100644
--- a/src/transformers/modeling_openai.py
+++ b/src/transformers/modeling_openai.py
@@ -541,6 +541,21 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        token_type_ids = kwargs.get("token_type_ids")
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+        return {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+        }
+
     @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
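
A minimal usage sketch follows (not part of the patch). It assumes the generate() loop of this transformers version forwards extra keyword arguments through to prepare_inputs_for_generation, so token_type_ids supplied at generation time reach the model and are trimmed to the last position on cached decoding steps. The prompt, model name, and all-zero segment ids below are illustrative only.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello, my dog is", return_tensors="pt")
# Illustrative segment ids only: all zeros here. Base GPT-2 was not pretrained
# with segment ids, so this mainly matters for models fine-tuned with them
# (e.g. dialogue models that assign one segment id per speaker).
token_type_ids = torch.zeros_like(input_ids)

output = model.generate(
    input_ids,
    token_type_ids=token_type_ids,
    max_length=20,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))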