Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 72 additions & 80 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,16 +972,16 @@ def _get_collator_with_removed_columns(
)
return remove_columns_collator

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
def _get_train_sampler(self, train_dataset) -> Optional[torch.utils.data.Sampler]:
if train_dataset is None or not has_length(train_dataset):
return None

# Build the sampler.
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
lengths = (
self.train_dataset[self.args.length_column_name]
if self.args.length_column_name in self.train_dataset.column_names
train_dataset[self.args.length_column_name]
if self.args.length_column_name in train_dataset.column_names
else None
)
else:
Expand All @@ -991,50 +991,80 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
)
return LengthGroupedSampler(
self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
dataset=train_dataset,
lengths=lengths,
model_input_name=model_input_name,
)

else:
return RandomSampler(self.train_dataset)
return RandomSampler(train_dataset)

def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].

Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
training if necessary) otherwise.

Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
def _get_dataloader(
self,
dataset: Dataset,
description: str,
batch_size: int,
sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
is_training: bool = False,
dataloader_key: Optional[str] = None,
) -> DataLoader:
"""Create a [`~torch.utils.data.DataLoader`] from the given dataset."""

train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
if is_datasets_available() and isinstance(dataset, datasets.Dataset):
dataset = self._remove_unused_columns(dataset, description=description)
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)

dataloader_params = {
"batch_size": self._train_batch_size,
"batch_size": batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
if not isinstance(dataset, torch.utils.data.IterableDataset):
if sampler_fn is not None:
dataloader_params["sampler"] = sampler_fn(dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if is_training:
dataloader_params["worker_init_fn"] = partial(
seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
)

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
dataloader = DataLoader(dataset, **dataloader_params)

# Accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version for eval dataloaders.
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader
else:
self._eval_dataloaders = {dataloader_key: dataloader}

return self.accelerator.prepare(dataloader)

def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].

Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
training if necessary) otherwise.

Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

return self._get_dataloader(
dataset=self.train_dataset,
description="Training",
batch_size=self._train_batch_size,
sampler_fn=self._get_train_sampler,
is_training=True,
)

def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
if eval_dataset is None or not has_length(eval_dataset):
Expand Down Expand Up @@ -1111,36 +1141,14 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None
if eval_dataset is not None
else self.eval_dataset
)
data_collator = self.data_collator

if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

# accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
if self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = eval_dataloader
else:
self._eval_dataloaders = {dataloader_key: eval_dataloader}

return self.accelerator.prepare(eval_dataloader)
return self._get_dataloader(
dataset=eval_dataset,
description="Evaluation",
batch_size=self.args.eval_batch_size,
sampler_fn=self._get_eval_sampler,
dataloader_key=dataloader_key,
)

def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
"""
Expand All @@ -1153,28 +1161,12 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. It must implement `__len__`.
"""
data_collator = self.data_collator

if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(test_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

# We use the same batch_size as for eval.
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
return self._get_dataloader(
dataset=test_dataset,
description="test",
batch_size=self.args.eval_batch_size,
sampler_fn=self._get_eval_sampler,
)

def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
Expand Down