From f173433855a86d5aadda0b9bb74a703d387a62ca Mon Sep 17 00:00:00 2001 From: yaswanth Date: Mon, 12 May 2025 23:09:10 +0530 Subject: [PATCH 1/2] Remove test_dataloader --- src/transformers/trainer.py | 36 +----------------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5886146002f1..fc231685b912 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1142,40 +1142,6 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None return self.accelerator.prepare(eval_dataloader) - def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: - """ - Returns the test [`~torch.utils.data.DataLoader`]. - - Subclass and override this method if you want to inject some custom behavior. - - Args: - test_dataset (`torch.utils.data.Dataset`, *optional*): - The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the - `model.forward()` method are automatically removed. It must implement `__len__`. - """ - data_collator = self.data_collator - - if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): - test_dataset = self._remove_unused_columns(test_dataset, description="test") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="test") - - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } - - if not isinstance(test_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_eval_sampler(test_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - # We use the same batch_size as for eval. - return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params)) - def create_optimizer_and_scheduler(self, num_training_steps: int): """ Setup the optimizer and the learning rate scheduler. @@ -4246,7 +4212,7 @@ def predict( # memory metrics - must set up as early as possible self._memory_tracker.start() - test_dataloader = self.get_test_dataloader(test_dataset) + test_dataloader = self.get_eval_dataloader(test_dataset) start_time = time.time() eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop From 3816a02ff9028b65b1b810c7794be2c54051108b Mon Sep 17 00:00:00 2001 From: yaswanth Date: Tue, 13 May 2025 19:59:45 +0530 Subject: [PATCH 2/2] refactor --- src/transformers/trainer.py | 138 +++++++++++++++++++++--------------- 1 file changed, 82 insertions(+), 56 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index fc231685b912..0b3fd219f8a4 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -972,16 +972,16 @@ def _get_collator_with_removed_columns( ) return remove_columns_collator - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.train_dataset is None or not has_length(self.train_dataset): + def _get_train_sampler(self, train_dataset) -> Optional[torch.utils.data.Sampler]: + if train_dataset is None or not has_length(train_dataset): return None # Build the sampler. if self.args.group_by_length: - if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): lengths = ( - self.train_dataset[self.args.length_column_name] - if self.args.length_column_name in self.train_dataset.column_names + train_dataset[self.args.length_column_name] + if self.args.length_column_name in train_dataset.column_names else None ) else: @@ -991,50 +991,80 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: ) return LengthGroupedSampler( self.args.train_batch_size * self.args.gradient_accumulation_steps, - dataset=self.train_dataset, + dataset=train_dataset, lengths=lengths, model_input_name=model_input_name, ) else: - return RandomSampler(self.train_dataset) + return RandomSampler(train_dataset) - def get_train_dataloader(self) -> DataLoader: - """ - Returns the training [`~torch.utils.data.DataLoader`]. - - Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed - training if necessary) otherwise. - - Subclass and override this method if you want to inject some custom behavior. - """ - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") + def _get_dataloader( + self, + dataset: Dataset, + description: str, + batch_size: int, + sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None, + is_training: bool = False, + dataloader_key: Optional[str] = None, + ) -> DataLoader: + """Create a [`~torch.utils.data.DataLoader`] from the given dataset.""" - train_dataset = self.train_dataset data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns(train_dataset, description="training") + if is_datasets_available() and isinstance(dataset, datasets.Dataset): + dataset = self._remove_unused_columns(dataset, description=description) else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description) dataloader_params = { - "batch_size": self._train_batch_size, + "batch_size": batch_size, "collate_fn": data_collator, "num_workers": self.args.dataloader_num_workers, "pin_memory": self.args.dataloader_pin_memory, "persistent_workers": self.args.dataloader_persistent_workers, } - if not isinstance(train_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_train_sampler() + if not isinstance(dataset, torch.utils.data.IterableDataset): + if sampler_fn is not None: + dataloader_params["sampler"] = sampler_fn(dataset) dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = partial( - seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index - ) dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + if is_training: + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + + dataloader = DataLoader(dataset, **dataloader_params) - return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + # Accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version for eval dataloaders. + if dataloader_key is not None and self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = dataloader + else: + self._eval_dataloaders = {dataloader_key: dataloader} + + return self.accelerator.prepare(dataloader) + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + return self._get_dataloader( + dataset=self.train_dataset, + description="Training", + batch_size=self._train_batch_size, + sampler_fn=self._get_train_sampler, + is_training=True, + ) def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: if eval_dataset is None or not has_length(eval_dataset): @@ -1111,36 +1141,32 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None if eval_dataset is not None else self.eval_dataset ) - data_collator = self.data_collator - if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): - eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") - else: - data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") - - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - "persistent_workers": self.args.dataloader_persistent_workers, - } + return self._get_dataloader( + dataset=eval_dataset, + description="Evaluation", + batch_size=self.args.eval_batch_size, + sampler_fn=self._get_eval_sampler, + dataloader_key=dataloader_key, + ) - if not isinstance(eval_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + """ + Returns the test [`~torch.utils.data.DataLoader`]. - # accelerator.free_memory() will destroy the references, so - # we need to store the non-prepared version - eval_dataloader = DataLoader(eval_dataset, **dataloader_params) - if self.args.dataloader_persistent_workers: - if hasattr(self, "_eval_dataloaders"): - self._eval_dataloaders[dataloader_key] = eval_dataloader - else: - self._eval_dataloaders = {dataloader_key: eval_dataloader} + Subclass and override this method if you want to inject some custom behavior. - return self.accelerator.prepare(eval_dataloader) + Args: + test_dataset (`torch.utils.data.Dataset`, *optional*): + The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. It must implement `__len__`. + """ + return self._get_dataloader( + dataset=test_dataset, + description="test", + batch_size=self.args.eval_batch_size, + sampler_fn=self._get_eval_sampler, + ) def create_optimizer_and_scheduler(self, num_training_steps: int): """ @@ -4212,7 +4238,7 @@ def predict( # memory metrics - must set up as early as possible self._memory_tracker.start() - test_dataloader = self.get_eval_dataloader(test_dataset) + test_dataloader = self.get_test_dataloader(test_dataset) start_time = time.time() eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop