diff --git a/examples/pytorch/contrastive-image-text/run_clip.py b/examples/pytorch/contrastive-image-text/run_clip.py
index 3bed494b75c6..8353333ef827 100644
--- a/examples/pytorch/contrastive-image-text/run_clip.py
+++ b/examples/pytorch/contrastive-image-text/run_clip.py
@@ -141,10 +141,6 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
     )
-    test_file: Optional[str] = field(
-        default=None,
-        metadata={"help": "An optional input testing data file (a jsonlines file)."},
-    )
     max_seq_length: Optional[int] = field(
         default=128,
         metadata={
@@ -190,9 +186,6 @@ def __post_init__(self):
         if self.validation_file is not None:
             extension = self.validation_file.split(".")[-1]
             assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
-        if self.test_file is not None:
-            extension = self.test_file.split(".")[-1]
-            assert extension in ["csv", "json"], "`test_file` should be a csv or a json file."
 
 
 dataset_name_mapping = {
@@ -315,9 +308,6 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
             extension = data_args.validation_file.split(".")[-1]
-        if data_args.test_file is not None:
-            data_files["test"] = data_args.test_file
-            extension = data_args.test_file.split(".")[-1]
         dataset = load_dataset(
             extension,
             data_files=data_files,
@@ -387,8 +377,6 @@ def _freeze_params(module):
         column_names = dataset["train"].column_names
     elif training_args.do_eval:
         column_names = dataset["validation"].column_names
-    elif training_args.do_predict:
-        column_names = dataset["test"].column_names
     else:
         logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
         return
@@ -490,29 +478,6 @@ def filter_corrupt_images(examples):
     # Transform images on the fly as doing it on the whole dataset takes too much time.
     eval_dataset.set_transform(transform_images)
 
-    if training_args.do_predict:
-        if "test" not in dataset:
-            raise ValueError("--do_predict requires a test dataset")
-        test_dataset = dataset["test"]
-        if data_args.max_eval_samples is not None:
-            max_eval_samples = min(len(test_dataset), data_args.max_eval_samples)
-            test_dataset = test_dataset.select(range(max_eval_samples))
-
-        test_dataset = test_dataset.filter(
-            filter_corrupt_images, batched=True, num_proc=data_args.preprocessing_num_workers
-        )
-        test_dataset = test_dataset.map(
-            function=tokenize_captions,
-            batched=True,
-            num_proc=data_args.preprocessing_num_workers,
-            remove_columns=[col for col in column_names if col != image_column],
-            load_from_cache_file=not data_args.overwrite_cache,
-            desc="Running tokenizer on test dataset",
-        )
-
-        # Transform images on the fly as doing it on the whole dataset takes too much time.
-        test_dataset.set_transform(transform_images)
-
     # 8. Initialize our trainer
     trainer = Trainer(
         model=model,