diff --git a/scripts/gpt2-tf2/gpt2_train_distributed.py b/scripts/gpt2-tf2/gpt2_train_distributed.py index 771b7c4fad4f..ddb715488557 100644 --- a/scripts/gpt2-tf2/gpt2_train_distributed.py +++ b/scripts/gpt2-tf2/gpt2_train_distributed.py @@ -50,7 +50,7 @@ def get_dataset(fil): def tokenize(data, tokenizer, truncate=False): if truncate: - data = tokenizer(data[:1000], return_tensors='tf', padding=True, truncation=True) + data = tokenizer(data[:100], return_tensors='tf', padding=True, truncation=True) else: data = tokenizer(data, return_tensors='tf', padding=True, truncation=True) return tf.data.Dataset.from_tensor_slices((dict(data), data['input_ids']))