Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/image-classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ python run_image_classification.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_train \
--do_eval \
--learning_rate 3e-5 \
Expand Down Expand Up @@ -182,6 +183,7 @@ python ../gaudi_spawn.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_train \
--do_eval \
--learning_rate 2e-4 \
Expand Down Expand Up @@ -221,6 +223,7 @@ python ../gaudi_spawn.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_train \
--do_eval \
--learning_rate 2e-4 \
Expand Down Expand Up @@ -276,6 +279,7 @@ python run_image_classification.py \
--dataset_name cifar10 \
--output_dir /tmp/outputs/ \
--remove_unused_columns False \
--image_column_name img \
--do_eval \
--per_device_eval_batch_size 64 \
--use_habana \
Expand Down
43 changes: 32 additions & 11 deletions examples/image-classification/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ class DataTrainingArguments:
)
},
)
image_column_name: str = field(
default="image",
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
)
label_column_name: str = field(
default="label",
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
)

def __post_init__(self):
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
Expand Down Expand Up @@ -183,12 +191,6 @@ class ModelArguments:
)


def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["labels"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}


def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
Expand Down Expand Up @@ -272,7 +274,6 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
token=model_args.token,
)
else:
Expand All @@ -285,9 +286,27 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
)

dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
if data_args.image_column_name not in dataset_column_names:
raise ValueError(
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--image_column_name` to the correct audio column - one of "
f"{', '.join(dataset_column_names)}."
)
if data_args.label_column_name not in dataset_column_names:
raise ValueError(
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--label_column_name` to the correct text column - one of "
f"{', '.join(dataset_column_names)}."
)

def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}

# If we don't have a validation split, split off a percentage of train as validation.
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
Expand All @@ -297,7 +316,7 @@ def main():

# Prepare label mappings.
# We'll include these in the model's config to get human readable labels in the Inference API.
labels = dataset["train"].features["labels"].names
labels = dataset["train"].features[data_args.label_column_name].names
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
Expand Down Expand Up @@ -371,13 +390,15 @@ def compute_metrics(p):
def train_transforms(example_batch):
"""Apply _train_transforms across a batch."""
example_batch["pixel_values"] = [
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

def val_transforms(example_batch):
"""Apply _val_transforms across a batch."""
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
example_batch["pixel_values"] = [
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
]
return example_batch

if training_args.do_train:
Expand Down
4 changes: 4 additions & 0 deletions tests/baselines/swin_base_patch4_window7_224_in22k.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"train_samples_per_second": 203.619,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand All @@ -28,6 +29,7 @@
"train_samples_per_second": 1679.61,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand All @@ -52,6 +54,7 @@
"train_samples_per_second": 840.673,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand All @@ -68,6 +71,7 @@
"train_samples_per_second": 5820.915,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--ignore_mismatched_sizes",
Expand Down
4 changes: 4 additions & 0 deletions tests/baselines/vit_base_patch16_224_in21k.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"train_samples_per_second": 349.875,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand All @@ -27,6 +28,7 @@
"train_samples_per_second": 2509.027,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand All @@ -51,6 +53,7 @@
"train_samples_per_second": 904.475,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand All @@ -66,6 +69,7 @@
"train_samples_per_second": 4251.991,
"extra_arguments": [
"--remove_unused_columns False",
"--image_column_name img",
"--seed 1337",
"--use_hpu_graphs_for_inference",
"--dataloader_num_workers 1",
Expand Down