@@ -489,7 +489,7 @@ def collate_fn(examples):
        predictions, references = accelerator.gather((predictions, batch["labels"]))
        # If we are in a multiprocess environment, the last batch has duplicates
        if accelerator.num_processes > 1:
-            if step == len(eval_dataloader):
+            if step == len(eval_dataloader) - 1:
                predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                references = references[: len(eval_dataloader.dataset) - samples_seen]
            else:
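The same one-line fix recurs in several example scripts below. Here is a small self-contained sketch of the bookkeeping it corrects (toy numbers, no Accelerate required; nothing here is taken from the PR itself): with `enumerate(eval_dataloader)`, `step` tops out at `len(eval_dataloader) - 1`, so the old `step == len(eval_dataloader)` check never fired and the padding duplicates in the last gathered batch were never trimmed from the metric.

```python
# Toy simulation of the gather-and-trim logic (assumed numbers: 10 real samples,
# 2 processes x batch_size 2, so each gathered batch holds 4 predictions and the
# last one is padded with duplicates).
dataset_size = 10
gathered_batch_size = 4
num_batches = -(-dataset_size // gathered_batch_size)  # ceil division -> 3

samples_seen = 0
kept = []
for step in range(num_batches):
    gathered = list(range(step * gathered_batch_size, (step + 1) * gathered_batch_size))
    if step == num_batches - 1:  # the fixed check; `step == num_batches` never happens
        gathered = gathered[: dataset_size - samples_seen]
    else:
        samples_seen += len(gathered)
    kept.extend(gathered)

assert len(kept) == dataset_size  # duplicates from the padded last batch are dropped
```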
examples/pytorch/multiple-choice/run_swag_no_trainer.py (2 changes: 1 addition & 1 deletion)
@@ -574,7 +574,7 @@ def preprocess_function(examples):
        predictions, references = accelerator.gather((predictions, batch["labels"]))
        # If we are in a multiprocess environment, the last batch has duplicates
        if accelerator.num_processes > 1:
-            if step == len(eval_dataloader):
+            if step == len(eval_dataloader) - 1:
                predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                references = references[: len(eval_dataloader.dataset) - samples_seen]
            else:
@@ -591,7 +591,7 @@ def preprocess_val(example_batch):

        # If we are in a multiprocess environment, the last batch has duplicates
        if accelerator.num_processes > 1:
-            if step == len(eval_dataloader):
+            if step == len(eval_dataloader) - 1:
                predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                references = references[: len(eval_dataloader.dataset) - samples_seen]
            else:
examples/pytorch/summarization/run_summarization_no_trainer.py (11 changes: 5 additions & 6 deletions)
@@ -310,7 +310,9 @@ def parse_args():

def main():
    args = parse_args()

+    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+    # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
+    accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
    if args.source_prefix is None and args.model_name_or_path in [
        "t5-small",
        "t5-base",
@@ -322,9 +324,6 @@ def main():
            "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
            "`--source_prefix 'summarize: ' `"
        )
-    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
-    # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
-    accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -675,11 +674,11 @@ def postprocess_text(preds, labels):
            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
            # If we are in a multiprocess environment, the last batch has duplicates
            if accelerator.num_processes > 1:
-                if step == len(eval_dataloader):
+                if step == len(eval_dataloader) - 1:
                    decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
                    decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
                else:
-                    samples_seen += decoded_labels.shape[0]
+                    samples_seen += len(decoded_labels)

            metric.add_batch(
                predictions=decoded_preds,
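A brief aside on the `samples_seen` change above: after `tokenizer.batch_decode` and `postprocess_text`, `decoded_labels` is a plain Python list of strings, so it has no `.shape` attribute and counting with `len()` is what actually works. A toy illustration (the values below are invented, not from the PR):

```python
# Hypothetical decoded labels, just to show why len() is the safer counter here.
decoded_labels = ["first reference summary", "second reference summary"]

# A plain list has no .shape attribute, so the tensor-style access raises:
try:
    n = decoded_labels.shape[0]
except AttributeError:
    n = len(decoded_labels)  # works for lists, tuples, and tensors alike

print(n)  # 2
```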
examples/pytorch/test_accelerate_examples.py (187 changes: 79 additions & 108 deletions)
@@ -18,49 +18,18 @@
import json
import logging
import os
+import shutil
+import subprocess
-import sys
-from unittest.mock import patch
+import tempfile

import torch

+from accelerate.utils import write_basic_config
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow, torch_device
from transformers.utils import is_apex_available


-SRC_DIRS = [
-    os.path.join(os.path.dirname(__file__), dirname)
-    for dirname in [
-        "text-generation",
-        "text-classification",
-        "token-classification",
-        "language-modeling",
-        "multiple-choice",
-        "question-answering",
-        "summarization",
-        "translation",
-        "image-classification",
-        "speech-recognition",
-        "audio-classification",
-        "speech-pretraining",
-        "image-pretraining",
-        "semantic-segmentation",
-    ]
-]
-sys.path.extend(SRC_DIRS)


-if SRC_DIRS is not None:
-    import run_clm_no_trainer
-    import run_glue_no_trainer
-    import run_image_classification_no_trainer
-    import run_mlm_no_trainer
-    import run_ner_no_trainer
-    import run_qa_no_trainer as run_squad_no_trainer
-    import run_semantic_segmentation_no_trainer
-    import run_summarization_no_trainer
-    import run_swag_no_trainer
-    import run_translation_no_trainer

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
@@ -94,10 +63,22 @@ def is_cuda_and_apex_available():


class ExamplesTestsNoTrainer(TestCasePlus):
+    @classmethod
+    def setUpClass(cls):
+        # Write Accelerate config, will pick up on CPU, GPU, and multi-GPU
+        cls.tmpdir = tempfile.mkdtemp()
+        cls.configPath = os.path.join(cls.tmpdir, "default_config.yml")
+        write_basic_config(save_location=cls.configPath)
+        cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmpdir)

    def test_run_glue_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
-            run_glue_no_trainer.py
+            {self.examples_dir}/pytorch/text-classification/run_glue_no_trainer.py
            --model_name_or_path distilbert-base-uncased
            --output_dir {tmp_dir}
            --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
@@ -113,17 +94,16 @@ def test_run_glue_no_trainer(self):
if is_cuda_and_apex_available():
testargs.append("--fp16")

-        with patch.object(sys, "argv", testargs):
-            run_glue_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "glue_no_trainer")))

def test_run_clm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_clm_no_trainer.py
+            {self.examples_dir}/pytorch/language-modeling/run_clm_no_trainer.py
--model_name_or_path distilgpt2
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
@@ -140,17 +120,16 @@ def test_run_clm_no_trainer(self):
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return

-        with patch.object(sys, "argv", testargs):
-            run_clm_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertLess(result["perplexity"], 100)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertLess(result["perplexity"], 100)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "clm_no_trainer")))

def test_run_mlm_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_mlm_no_trainer.py
+            {self.examples_dir}/pytorch/language-modeling/run_mlm_no_trainer.py
--model_name_or_path distilroberta-base
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
@@ -160,20 +139,19 @@ def test_run_mlm_no_trainer(self):
--with_tracking
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_mlm_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertLess(result["perplexity"], 42)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertLess(result["perplexity"], 42)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "mlm_no_trainer")))

def test_run_ner_no_trainer(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2

tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_ner_no_trainer.py
+            {self.examples_dir}/pytorch/token-classification/run_ner_no_trainer.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/conll/sample.json
--validation_file tests/fixtures/tests_samples/conll/sample.json
@@ -187,18 +165,17 @@ def test_run_ner_no_trainer(self):
--with_tracking
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_ner_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_accuracy"], 0.75)
-            self.assertLess(result["train_loss"], 0.5)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_accuracy"], 0.75)
+        self.assertLess(result["train_loss"], 0.5)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "ner_no_trainer")))

def test_run_squad_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_qa_no_trainer.py
+            {self.examples_dir}/pytorch/question-answering/run_qa_no_trainer.py
--model_name_or_path bert-base-uncased
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
@@ -213,19 +190,18 @@ def test_run_squad_no_trainer(self):
--with_tracking
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_squad_no_trainer.main()
-            result = get_results(tmp_dir)
-            # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
-            self.assertGreaterEqual(result["eval_f1"], 30)
-            self.assertGreaterEqual(result["eval_exact"], 30)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
+        self.assertGreaterEqual(result["eval_f1"], 30)
+        self.assertGreaterEqual(result["eval_exact"], 30)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "qa_no_trainer")))

def test_run_swag_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_swag_no_trainer.py
+            {self.examples_dir}/pytorch/multiple-choice/run_swag_no_trainer.py
--model_name_or_path bert-base-uncased
--train_file tests/fixtures/tests_samples/swag/sample.json
--validation_file tests/fixtures/tests_samples/swag/sample.json
@@ -238,17 +214,16 @@ def test_run_swag_no_trainer(self):
--with_tracking
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_swag_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_accuracy"], 0.8)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_accuracy"], 0.8)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "swag_no_trainer")))

@slow
def test_run_summarization_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_summarization_no_trainer.py
+            {self.examples_dir}/pytorch/summarization/run_summarization_no_trainer.py
--model_name_or_path t5-small
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
@@ -262,21 +237,20 @@ def test_run_summarization_no_trainer(self):
--with_tracking
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_summarization_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_rouge1"], 10)
-            self.assertGreaterEqual(result["eval_rouge2"], 2)
-            self.assertGreaterEqual(result["eval_rougeL"], 7)
-            self.assertGreaterEqual(result["eval_rougeLsum"], 7)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_rouge1"], 10)
+        self.assertGreaterEqual(result["eval_rouge2"], 2)
+        self.assertGreaterEqual(result["eval_rougeL"], 7)
+        self.assertGreaterEqual(result["eval_rougeLsum"], 7)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "summarization_no_trainer")))

@slow
def test_run_translation_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_translation_no_trainer.py
+            {self.examples_dir}/pytorch/translation/run_translation_no_trainer.py
--model_name_or_path sshleifer/student_marian_en_ro_6_1
--source_lang en
--target_lang ro
@@ -294,12 +268,11 @@ def test_run_translation_no_trainer(self):
--with_tracking
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_translation_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_bleu"], 30)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_bleu"], 30)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "translation_no_trainer")))

@slow
def test_run_semantic_segmentation_no_trainer(self):
@@ -308,7 +281,7 @@ def test_run_semantic_segmentation_no_trainer(self):

tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_semantic_segmentation_no_trainer.py
+            {self.examples_dir}/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py
--dataset_name huggingface/semantic-segmentation-test-sample
--output_dir {tmp_dir}
--max_train_steps=10
@@ -319,15 +292,14 @@ def test_run_semantic_segmentation_no_trainer(self):
--checkpointing_steps epoch
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_semantic_segmentation_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_overall_accuracy"], 0.10)

def test_run_image_classification_no_trainer(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
-            run_image_classification_no_trainer.py
+            {self.examples_dir}/pytorch/image-classification/run_image_classification_no_trainer.py
--dataset_name huggingface/image-classification-test-sample
--output_dir {tmp_dir}
--num_warmup_steps=8
@@ -339,9 +311,8 @@ def test_run_image_classification_no_trainer(self):
--seed 42
""".split()

-        with patch.object(sys, "argv", testargs):
-            run_image_classification_no_trainer.main()
-            result = get_results(tmp_dir)
-            self.assertGreaterEqual(result["eval_accuracy"], 0.50)
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
-            self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
+        _ = subprocess.run(self._launch_args + testargs, stdout=subprocess.PIPE)
+        result = get_results(tmp_dir)
+        self.assertGreaterEqual(result["eval_accuracy"], 0.50)
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
+        self.assertTrue(os.path.exists(os.path.join(tmp_dir, "image_classification_no_trainer")))
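To make the intent of the test refactor above easier to follow, here is a standalone sketch of the launch pattern the new tests rely on: write a default Accelerate config once, then drive each example script through `accelerate launch` in a subprocess instead of patching `sys.argv` and importing the script. The script path and arguments below are placeholders chosen for illustration, not the PR's exact values.

```python
import os
import shutil
import subprocess
import tempfile

from accelerate.utils import write_basic_config

# One-time setup: a basic config that adapts to CPU, single-GPU, or multi-GPU.
tmpdir = tempfile.mkdtemp()
config_path = os.path.join(tmpdir, "default_config.yml")
write_basic_config(save_location=config_path)
launch_args = ["accelerate", "launch", "--config_file", config_path]

# Hypothetical script and arguments, just to show the shape of the call.
script_args = [
    "examples/pytorch/text-classification/run_glue_no_trainer.py",
    "--model_name_or_path", "distilbert-base-uncased",
    "--output_dir", os.path.join(tmpdir, "glue_out"),
]

# The test then inspects files the script wrote (e.g. all_results.json)
# rather than return values from an in-process main().
completed = subprocess.run(launch_args + script_args, stdout=subprocess.PIPE)
print(completed.returncode)

shutil.rmtree(tmpdir)
```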
@@ -528,7 +528,7 @@ def preprocess_function(examples):
        predictions, references = accelerator.gather((predictions, batch["labels"]))
        # If we are in a multiprocess environment, the last batch has duplicates
        if accelerator.num_processes > 1:
-            if step == len(eval_dataloader):
+            if step == len(eval_dataloader) - 1:
                predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                references = references[: len(eval_dataloader.dataset) - samples_seen]
            else: