From 503e85350120f02d3b6e0a5e43ae42273864044a Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 9 Jan 2024 15:09:32 +0100 Subject: [PATCH 1/2] Fix error in run_image_classification.py --- examples/image-classification/run_image_classification.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py index 2052117743..d0d5773427 100644 --- a/examples/image-classification/run_image_classification.py +++ b/examples/image-classification/run_image_classification.py @@ -272,7 +272,6 @@ def main(): data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, - task="image-classification", token=model_args.token, ) else: @@ -285,9 +284,14 @@ def main(): "imagefolder", data_files=data_files, cache_dir=model_args.cache_dir, - task="image-classification", ) + # Rename image and label columns if needed (e.g. Cifar10) + if "img" in (dataset["train"].features if "train" in dataset else dataset["validation"].features): + dataset = dataset.rename_column("img", "image") + if "label" in (dataset["train"].features if "train" in dataset else dataset["validation"].features): + dataset = dataset.rename_column("label", "labels") + # If we don't have a validation split, split off a percentage of train as validation. data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: From 024b22fd601eb8e219134d57ee8048adf942f7ad Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Thu, 11 Jan 2024 14:59:07 +0000 Subject: [PATCH 2/2] Fix image-classification example --- examples/image-classification/README.md | 4 ++ .../run_image_classification.py | 45 +++++++++++++------ .../swin_base_patch4_window7_224_in22k.json | 4 ++ .../baselines/vit_base_patch16_224_in21k.json | 4 ++ 4 files changed, 43 insertions(+), 14 deletions(-) diff --git a/examples/image-classification/README.md b/examples/image-classification/README.md index 088d1f66af..a95b5bd7bb 100644 --- a/examples/image-classification/README.md +++ b/examples/image-classification/README.md @@ -31,6 +31,7 @@ python run_image_classification.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_train \ --do_eval \ --learning_rate 3e-5 \ @@ -182,6 +183,7 @@ python ../gaudi_spawn.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_train \ --do_eval \ --learning_rate 2e-4 \ @@ -221,6 +223,7 @@ python ../gaudi_spawn.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_train \ --do_eval \ --learning_rate 2e-4 \ @@ -276,6 +279,7 @@ python run_image_classification.py \ --dataset_name cifar10 \ --output_dir /tmp/outputs/ \ --remove_unused_columns False \ + --image_column_name img \ --do_eval \ --per_device_eval_batch_size 64 \ --use_habana \ diff --git a/examples/image-classification/run_image_classification.py b/examples/image-classification/run_image_classification.py index d0d5773427..547e6499e0 100644 --- a/examples/image-classification/run_image_classification.py +++ b/examples/image-classification/run_image_classification.py @@ -119,6 +119,14 @@ class DataTrainingArguments: ) }, ) + image_column_name: str = field( + default="image", + metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."}, + ) + label_column_name: str = field( + default="label", + metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."}, + ) def __post_init__(self): if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None): @@ -183,12 +191,6 @@ class ModelArguments: ) -def collate_fn(examples): - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - labels = torch.tensor([example["labels"] for example in examples]) - return {"pixel_values": pixel_values, "labels": labels} - - def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. @@ -286,11 +288,24 @@ def main(): cache_dir=model_args.cache_dir, ) - # Rename image and label columns if needed (e.g. Cifar10) - if "img" in (dataset["train"].features if "train" in dataset else dataset["validation"].features): - dataset = dataset.rename_column("img", "image") - if "label" in (dataset["train"].features if "train" in dataset else dataset["validation"].features): - dataset = dataset.rename_column("label", "labels") + dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names + if data_args.image_column_name not in dataset_column_names: + raise ValueError( + f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--image_column_name` to the correct audio column - one of " + f"{', '.join(dataset_column_names)}." + ) + if data_args.label_column_name not in dataset_column_names: + raise ValueError( + f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--label_column_name` to the correct text column - one of " + f"{', '.join(dataset_column_names)}." + ) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + labels = torch.tensor([example[data_args.label_column_name] for example in examples]) + return {"pixel_values": pixel_values, "labels": labels} # If we don't have a validation split, split off a percentage of train as validation. data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split @@ -301,7 +316,7 @@ def main(): # Prepare label mappings. # We'll include these in the model's config to get human readable labels in the Inference API. - labels = dataset["train"].features["labels"].names + labels = dataset["train"].features[data_args.label_column_name].names label2id, id2label = {}, {} for i, label in enumerate(labels): label2id[label] = str(i) @@ -375,13 +390,15 @@ def compute_metrics(p): def train_transforms(example_batch): """Apply _train_transforms across a batch.""" example_batch["pixel_values"] = [ - _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"] + _train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] ] return example_batch def val_transforms(example_batch): """Apply _val_transforms across a batch.""" - example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]] + example_batch["pixel_values"] = [ + _val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name] + ] return example_batch if training_args.do_train: diff --git a/tests/baselines/swin_base_patch4_window7_224_in22k.json b/tests/baselines/swin_base_patch4_window7_224_in22k.json index 053258d3f0..94b9934137 100644 --- a/tests/baselines/swin_base_patch4_window7_224_in22k.json +++ b/tests/baselines/swin_base_patch4_window7_224_in22k.json @@ -12,6 +12,7 @@ "train_samples_per_second": 203.619, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--ignore_mismatched_sizes", @@ -28,6 +29,7 @@ "train_samples_per_second": 1679.61, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--ignore_mismatched_sizes", @@ -52,6 +54,7 @@ "train_samples_per_second": 840.673, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--ignore_mismatched_sizes", @@ -68,6 +71,7 @@ "train_samples_per_second": 5820.915, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--ignore_mismatched_sizes", diff --git a/tests/baselines/vit_base_patch16_224_in21k.json b/tests/baselines/vit_base_patch16_224_in21k.json index 8b53e3f89b..96c945bfc0 100644 --- a/tests/baselines/vit_base_patch16_224_in21k.json +++ b/tests/baselines/vit_base_patch16_224_in21k.json @@ -12,6 +12,7 @@ "train_samples_per_second": 349.875, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1", @@ -27,6 +28,7 @@ "train_samples_per_second": 2509.027, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1", @@ -51,6 +53,7 @@ "train_samples_per_second": 904.475, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1", @@ -66,6 +69,7 @@ "train_samples_per_second": 4251.991, "extra_arguments": [ "--remove_unused_columns False", + "--image_column_name img", "--seed 1337", "--use_hpu_graphs_for_inference", "--dataloader_num_workers 1",