2 changes: 2 additions & 0 deletions examples/seq2seq/finetune_trainer.py
100644 → 100755 (file made executable)
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+
 import logging
 import os
 import sys
3 changes: 2 additions & 1 deletion examples/seq2seq/seq2seq_trainer.py
@@ -122,7 +122,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         else:
             if self.args.sortish_sampler:
                 self.train_dataset.make_sortish_sampler(
-                    self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
+                    self.args.per_device_train_batch_size,
+                    distributed=(self.args.local_rank != -1),
                 )
 
             return (
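The hunk above swaps the sortish sampler's distributed flag from n_gpu > 1 to local_rank != -1: local_rank stays at -1 unless the process was started by a distributed launcher, while n_gpu > 1 is also true for single-process DataParallel, where a distributed sampler would be wrong. A minimal sketch of the distinction (the helper below is illustrative, not code from this PR):

# Illustrative helper, not from the PR: shows why local_rank is a better
# "distributed" signal than n_gpu.
def training_mode(local_rank: int, n_gpu: int) -> str:
    if local_rank != -1:
        # One process per GPU, started by torch.distributed.launch -> DDP.
        return "ddp"
    if n_gpu > 1:
        # A single process driving several GPUs -> DataParallel, not distributed.
        return "dp"
    return "single"

assert training_mode(local_rank=-1, n_gpu=2) == "dp"      # the old n_gpu > 1 check treated this as distributed
assert training_mode(local_rank=0, n_gpu=1) == "ddp"
assert training_mode(local_rank=-1, n_gpu=1) == "single"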
39 changes: 31 additions & 8 deletions examples/seq2seq/test_finetune_trainer.py
@@ -4,7 +4,14 @@
 
 from transformers import BertTokenizer, EncoderDecoderModel
 from transformers.file_utils import is_datasets_available
-from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    get_gpu_count,
+    require_torch_multi_gpu,
+    require_torch_non_multi_gpu,
+    slow,
+)
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
 
@@ -18,17 +25,32 @@
 
 
 class TestFinetuneTrainer(TestCasePlus):
-    def test_finetune_trainer(self):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
+    def finetune_trainer_quick(self, distributed=None):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
         assert "eval_bleu" in first_step_stats
 
+    @require_torch_non_multi_gpu
+    def test_finetune_trainer_no_dist(self):
+        self.finetune_trainer_quick()
+
+    # the following 2 tests verify that the trainer can handle distributed and non-distributed with n_gpu > 1
+    @require_torch_multi_gpu
+    def test_finetune_trainer_dp(self):
+        self.finetune_trainer_quick(distributed=False)
+
+    @require_torch_multi_gpu
+    def test_finetune_trainer_ddp(self):
+        self.finetune_trainer_quick(distributed=True)
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
-        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)
+        output_dir = self.run_trainer(
+            eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
+        )
 
         # Check metrics
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
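The new dp/ddp tests are gated with require_torch_multi_gpu so they only run when more than one GPU is available, while the quick non-distributed test is kept off multi-GPU machines via require_torch_non_multi_gpu. As an illustration only, a skip gate of this kind can be built from unittest.skipUnless and torch.cuda.device_count(); the real decorators live in transformers.testing_utils and are not reproduced here:

# Illustration only -- not the transformers.testing_utils implementation.
import unittest

import torch


def needs_multi_gpu(test_case):
    """Skip the decorated test unless more than one CUDA device is visible."""
    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


def needs_0_or_1_gpu(test_case):
    """Skip the decorated test when more than one CUDA device is visible."""
    return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)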
@@ -158,7 +180,9 @@ def _compute_metrics(pred):
         # start training
         trainer.train()
 
-    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+    def run_trainer(
+        self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
+    ):
         data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
         args = f"""
@@ -193,8 +217,8 @@ def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_
         """.split()
         # --eval_beams 2
 
-        n_gpu = get_gpu_count()
-        if n_gpu > 1:
+        if distributed:
+            n_gpu = get_gpu_count()
             distributed_args = f"""
                 -m torch.distributed.launch
                 --nproc_per_node={n_gpu}
@@ -203,7 +227,6 @@ def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_
             cmd = [sys.executable] + distributed_args + args
             execute_subprocess_async(cmd, env=self.get_env())
         else:
-            # 0 or 1 gpu
             testargs = ["finetune_trainer.py"] + args
             with patch.object(sys, "argv", testargs):
                 main()
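With distributed=True, run_trainer now prepends a torch.distributed.launch invocation and runs finetune_trainer.py in a subprocess via execute_subprocess_async; otherwise it patches sys.argv and calls main() in-process. A rough sketch of the command the distributed branch assembles, with placeholder script path and training arguments rather than the exact values from the test:

# Sketch of the launch command built in the distributed branch; the script
# path and training arguments below are placeholders.
import sys

n_gpu = 2  # e.g. what get_gpu_count() reports
script = "examples/seq2seq/finetune_trainer.py"
train_args = ["--model_name_or_path", "sshleifer/tiny-mbart", "--output_dir", "/tmp/out", "--do_train"]

cmd = [sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={n_gpu}", script] + train_args
# roughly: python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/finetune_trainer.py --model_name_or_path ...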