103 changes: 71 additions & 32 deletions src/transformers/trainer.py
@@ -16,7 +16,9 @@
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

import collections
import inspect
import math
import os
import re
import shutil
@@ -283,6 +285,15 @@ def __init__(
FutureWarning,
)

if args.max_steps > 0:
logger.info("max_steps is given, it will override any value given in num_train_epochs")

# Enforce rules on using datasets with no __len__
if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
raise ValueError("train_dataset does not implement __len__, max_steps has to be specified")
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")

if is_datasets_available():
if isinstance(train_dataset, datasets.Dataset):
self._remove_unused_columns(self.train_dataset, description="training")
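The new guard hinges on collections.abc.Sized, which only recognizes objects whose class defines __len__. A minimal standalone sketch (not part of the diff) of what the check sees for a plain IterableDataset:

import collections.abc
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    # No __len__ defined, so the Trainer cannot infer the number of steps per epoch.
    def __iter__(self):
        yield from range(10)

# Not Sized -> the constructor above requires args.max_steps > 0 for training
# and rejects such a dataset outright as an eval_dataset.
print(isinstance(StreamingDataset(), collections.abc.Sized))  # False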
@@ -361,7 +372,7 @@ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optio
dataset.set_format(type=dataset.format["type"], columns=columns)

def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
if not isinstance(self.train_dataset, collections.abc.Sized):
return None
elif is_torch_tpu_available():
return get_tpu_sampler(self.train_dataset)
@@ -376,7 +387,7 @@ def get_train_dataloader(self) -> DataLoader:
"""
Returns the training :class:`~torch.utils.data.DataLoader`.

Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
Will use no sampler if :obj:`self.train_dataset` does not implement :obj:`__len__`, a random sampler
(adapted to distributed training if necessary) otherwise.

Subclass and override this method if you want to inject some custom behavior.
@@ -395,9 +406,7 @@ def get_train_dataloader(self) -> DataLoader:
)

def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
return None
elif is_torch_tpu_available():
if is_torch_tpu_available():
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif self.args.local_rank != -1:
return SequentialDistributedSampler(eval_dataset)
@@ -408,19 +417,18 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
"""
Returns the evaluation :class:`~torch.utils.data.DataLoader`.

Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
sampler (adapted to distributed training if necessary) otherwise.

Subclass and override this method if you want to inject some custom behavior.

Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
accepted by the ``model.forward()`` method are automatically removed.
accepted by the ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
elif eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")
elif is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(eval_dataset, description="evaluation")
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
eval_sampler = self._get_eval_sampler(eval_dataset)
@@ -438,17 +446,16 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
"""
Returns the test :class:`~torch.utils.data.DataLoader`.

Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
sampler (adapted to distributed training if necessary) otherwise.

Subclass and override this method if you want to inject some custom behavior.

Args:
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
test_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.
"""
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
if not isinstance(test_dataset, collections.abc.Sized):
raise ValueError("test_dataset must implement __len__")
elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
self._remove_unused_columns(test_dataset, description="test")
test_sampler = self._get_eval_sampler(test_dataset)

@@ -494,6 +501,8 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
def num_examples(self, dataloader: DataLoader) -> int:
"""
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.

Will raise an exception if the underlying dataset does not implement :obj:`__len__`.
"""
return len(dataloader.dataset)
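To illustrate that note: calling len() on the dataset behind a DataLoader built from an IterableDataset raises a TypeError. A small sketch, independent of the diff:

from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(8)

loader = DataLoader(Stream(), batch_size=2)
try:
    print(len(loader.dataset))
except TypeError as err:
    # e.g. "object of type 'Stream' has no len()"
    print("num_examples would fail:", err)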

@@ -579,19 +588,32 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
# Reinitializes optimizer and scheduler
self.optimizer, self.lr_scheduler = None, None

# Keeping track of whether we can call len() on the dataset or not
train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)

# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
)

# Setting up training control variables:
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
if train_dataset_is_sized:
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
if self.args.max_steps > 0:
max_steps = self.args.max_steps
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
self.args.max_steps % num_update_steps_per_epoch > 0
)
else:
max_steps = math.ceil(self.args.num_train_epochs * num_update_steps_per_epoch)
num_train_epochs = math.ceil(self.args.num_train_epochs)
else:
max_steps = int(num_update_steps_per_epoch * self.args.num_train_epochs)
num_train_epochs = self.args.num_train_epochs
num_train_epochs = int(np.ceil(num_train_epochs))
# see __init__. max_steps is set when the dataset has no __len__
max_steps = self.args.max_steps
num_train_epochs = 1
num_update_steps_per_epoch = max_steps

self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
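A worked sketch of the step/epoch arithmetic above, with hypothetical numbers (500 batches, gradient accumulation of 2, max_steps=1000); the helper is illustrative only, not a Trainer method:

import math

def training_control(dataloader_len, accum_steps, max_steps, num_train_epochs, dataset_is_sized):
    # Mirrors the branch above.
    if dataset_is_sized:
        updates_per_epoch = max(dataloader_len // accum_steps, 1)
        if max_steps > 0:
            total_steps = max_steps
            epochs = max_steps // updates_per_epoch + int(max_steps % updates_per_epoch > 0)
        else:
            total_steps = math.ceil(num_train_epochs * updates_per_epoch)
            epochs = math.ceil(num_train_epochs)
    else:
        # No __len__: max_steps drives everything and counts as a single epoch.
        total_steps, epochs, updates_per_epoch = max_steps, 1, max_steps
    return total_steps, epochs, updates_per_epoch

print(training_control(500, 2, 1000, 3.0, True))   # (1000, 4, 250)
print(training_control(0, 1, 1000, 3.0, False))    # (1000, 1, 1000)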
@@ -645,8 +667,15 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
* self.args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
)

num_examples = (
self.num_examples(train_dataloader)
if train_dataset_is_sized
else total_train_batch_size * self.args.max_steps
)

logger.info("***** Running training *****")
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
logger.info(" Num examples = %d", num_examples)
logger.info(" Num Epochs = %d", num_train_epochs)
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
@@ -703,6 +732,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
if self.args.past_index >= 0:
self._past = None

steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)

for step, inputs in enumerate(epoch_iterator):
@@ -728,8 +758,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
len(epoch_iterator) <= self.args.gradient_accumulation_steps
and (step + 1) == len(epoch_iterator)
steps_in_epoch <= self.args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
if self.args.fp16 and _use_native_amp:
self.scaler.unscale_(self.optimizer)
@@ -750,7 +780,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.lr_scheduler.step()
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
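The fractional epoch now uses steps_in_epoch, so it stays meaningful even when len(epoch_iterator) is unavailable. With illustrative numbers:

epoch, step, steps_in_epoch = 2, 124, 250
# Half-way through the third epoch (epoch index 2):
print(epoch + (step + 1) / steps_in_epoch)   # 2.5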
@@ -1207,11 +1237,15 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
Args:
eval_dataset (:obj:`Dataset`, `optional`):
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed.
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement
the :obj:`__len__` method.

Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
"""
if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")

eval_dataloader = self.get_eval_dataloader(eval_dataset)

output = self.prediction_loop(eval_dataloader, description="Evaluation")
@@ -1234,7 +1268,7 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
Args:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed.
``model.forward()`` method are automatically removed. It must implement the :obj:`__len__` method.

Returns:
`NamedTuple`:
@@ -1245,6 +1279,9 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput:
metrics (:obj:`Dict[str, float]`, `optional`):
The potential dictionary of metrics (if the dataset contained labels).
"""
if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
raise ValueError("test_dataset must implement __len__")

test_dataloader = self.get_test_dataloader(test_dataset)

return self.prediction_loop(test_dataloader, description="Prediction")
@@ -1264,6 +1301,8 @@ def prediction_loop(
)
return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

if not isinstance(dataloader.dataset, collections.abc.Sized):
raise ValueError("dataset must implement __len__")
prediction_loss_only = (
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)
62 changes: 52 additions & 10 deletions tests/test_trainer.py
100755 → 100644
@@ -31,11 +31,14 @@
from torch.utils.data import IterableDataset

from transformers import (
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
DataCollatorForLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
PreTrainedModel,
TextDataset,
Trainer,
TrainerState,
)
@@ -83,15 +86,16 @@ def __init__(self, a=0, b=0, double_output=False, **kwargs):
if is_torch_available():

class SampleIterableDataset(IterableDataset):
def __init__(self, file_path):
self.file_path = file_path
"""
The criterion is not whether the dataset is an IterableDataset, but whether it implements __len__
"""

def parse_file(self):
f = open(self.file_path, "r")
return f.readlines()
def __init__(self, file_path, tokenizer):
self.ds = TextDataset(file_path=file_path, tokenizer=tokenizer, block_size=64)

def __iter__(self):
return iter(self.parse_file())
for i in range(len(self.ds)):
yield self.ds[i]

class RegressionModel(torch.nn.Module):
def __init__(self, a=0, b=0, double_output=False):
@@ -538,13 +542,51 @@ def test_trainer_eval_lm(self):
self.assertEqual(len(dataset), 31)

def test_trainer_iterable_dataset(self):
# Simulate Language Modeling with an IterableDataset, with no __len__ method
# Pick-up a tiny model, so it works on CPU
# See Issue #5990: https://github.com/huggingface/transformers/issues/5990
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(file_path=PATH_SAMPLE_TEXT, tokenizer=tokenizer)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
trainer.train()

loader = trainer.get_train_dataloader()
self.assertIsInstance(loader, torch.utils.data.DataLoader)
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)

# Exception if giving iterable dataset and no max_steps
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
_ = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)

# Exception if eval_dataset is iterable in __init__
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True, max_steps=2)
_ = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=train_dataset,
data_collator=data_collator,
)

# Exception if predicting with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.predict(train_dataset)

# Exception if evaluating with iterable dataset
with self.assertRaises(ValueError):
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, data_collator=data_collator)
trainer.evaluate(train_dataset)
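The _InfiniteConstantSampler assertion above can be reproduced outside the Trainer: PyTorch's DataLoader assigns that internal sampler whenever it is given an IterableDataset and no explicit sampler. A standalone sketch, assuming the torch 1.x releases current at the time of this PR:

from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(4)

loader = DataLoader(Stream(), batch_size=2)
print(type(loader.sampler).__name__)   # _InfiniteConstantSampler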

def test_num_train_epochs_in_training(self):
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.