diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md index 088d1f66af..a95b5bd7bb 100644 --- a/examples/image-classification/README.md +++ b/examples/image-classification/README.md @@ -31,6 +31,7 @@ python run_image_classification.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_train \ --do_eval \ --learning_rate 3e-5 \ @@ -182,6 +183,7 @@ python ../gaudi_spawn.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_train \ --do_eval \ --learning_rate 2e-4 \ @@ -221,6 +223,7 @@ python ../gaudi_spawn.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_train \ --do_eval \ --learning_rate 2e-4 \ @@ -276,6 +279,7 @@ python run_image_classification.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_eval \ --per_device_eval_batch_size 64 \ --use_habana \ diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py index 2052117743..547e6499e0 100644 --- a/examples/image-classification/run_image_classification.py +++ b/examples/image-classification/run_image_classification.py @@ -119,6 +119,14 @@ class DataTrainingArguments: ) }, ) + image_column_name: str = field( + default="image", + metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."}, + ) + label_column_name: str = field( + default="label", + metadata={"help": "The name of the dataset column containing the labels. 
Defaults to 'label'."}, + ) def __post_init__(self): if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None): @@ -183,12 +191,6 @@ class ModelArguments: ) -def collate_fn(examples): - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - labels = torch.tensor([example["labels"] for example in examples]) - return {"pixel_values": pixel_values, "labels": labels} - - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -272,7 +274,6 @@ def main(): data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, - task="image-classification", token=model_args.token, ) else: @@ -285,9 +286,27 @@ def main(): "imagefolder", data_files=data_files, cache_dir=model_args.cache_dir, - task="image-classification", ) + dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names + if data_args.image_column_name not in dataset_column_names: + raise ValueError( + f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--image_column_name` to the correct image column - one of " + f"{', '.join(dataset_column_names)}." + ) + if data_args.label_column_name not in dataset_column_names: + raise ValueError( + f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--label_column_name` to the correct label column - one of " + f"{', '.join(dataset_column_names)}." + ) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + labels = torch.tensor([example[data_args.label_column_name] for example in examples]) + return {"pixel_values": pixel_values, "labels": labels} + # If we don't have a validation split, split off a percentage of train as validation. 
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: @@ -297,7 +316,7 @@ def main(): # Prepare label mappings. # We'll include these in the model's config to get human readable labels in the Inference API. - labels = dataset["train"].features["labels"].names + labels = dataset["train"].features[data_args.label_column_name].names label2id, id2label = {}, {} for i, label in enumerate(labels): label2id[label] = str(i) @@ -371,13 +390,15 @@ def compute_metrics(p): def train_transforms(example_batch): """Apply _train_transforms across a batch.""" example_batch["pixel_values"] = [ - _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"] + _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] ] return example_batch def val_transforms(example_batch): """Apply _val_transforms across a batch.""" - example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]] + example_batch["pixel_values"] = [ + _val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] + ] return example_batch if training_args.do_train: diff --git a/tests/baselines/swin_base_patch4_window7_224_in22k.json b/tests/baselines/swin_base_patch4_window7_224_in22k.json index 053258d3f0..94b9934137 100644 --- a/tests/baselines/swin_base_patch4_window7_224_in22k.json +++ b/tests/baselines/swin_base_patch4_window7_224_in22k.json @@ -12,6 +12,7 @@ "train_samples_per_second": 203.619, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--ignore_mismatched_sizes", @@ -28,6 +29,7 @@ "train_samples_per_second": 1679.61, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", 
"--ignore_mismatched_sizes", @@ -52,6 +54,7 @@ "train_samples_per_second": 840.673, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--ignore_mismatched_sizes", @@ -68,6 +71,7 @@ "train_samples_per_second": 5820.915, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--ignore_mismatched_sizes", diff --git a/tests/baselines/vit_base_patch16_224_in21k.json b/tests/baselines/vit_base_patch16_224_in21k.json index 8b53e3f89b..96c945bfc0 100644 --- a/tests/baselines/vit_base_patch16_224_in21k.json +++ b/tests/baselines/vit_base_patch16_224_in21k.json @@ -12,6 +12,7 @@ "train_samples_per_second": 349.875, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1", @@ -27,6 +28,7 @@ "train_samples_per_second": 2509.027, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1", @@ -51,6 +53,7 @@ "train_samples_per_second": 904.475, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1", @@ -66,6 +69,7 @@ "train_samples_per_second": 4251.991, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1",