diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 4b2ab96842..5355ceb1b6 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -457,12 +457,18 @@ def generate(size=None, reduce_recompile=False): encode_t0 = time.perf_counter() # Tokenization if args.max_input_tokens > 0: + if hasattr(model.config, "type_vocab_size") and model.config.type_vocab_size > 0: + return_token_type_ids = True + else: + return_token_type_ids = False + input_tokens = tokenizer.batch_encode_plus( input_sentences, return_tensors="pt", padding="max_length", max_length=args.max_input_tokens, truncation=True, + return_token_type_ids=return_token_type_ids, ) def compute_valid_sequence_lengths_tensor(input_tokens):