diff --git a/src/sparseml/transformers/question_answering.py b/src/sparseml/transformers/question_answering.py
index 90e452aedf9..91eb1c4a9b2 100644
--- a/src/sparseml/transformers/question_answering.py
+++ b/src/sparseml/transformers/question_answering.py
@@ -40,7 +40,6 @@
     EvalPrediction,
     HfArgumentParser,
     PreTrainedTokenizerFast,
-    TrainingArguments,
     default_data_collator,
     set_seed,
 )
@@ -50,6 +49,7 @@
 
 from sparseml.transformers.sparsification import (
     QuestionAnsweringTrainer,
+    TrainingArguments,
     postprocess_qa_predictions,
 )
 from sparseml.transformers.utils import SparseAutoModel, get_shared_tokenizer_src
@@ -90,10 +90,6 @@ class ModelArguments:
             )
         }
     )
-    distill_teacher: Optional[str] = field(
-        default=None,
-        metadata={"help": "Teacher model which needs to be a trained QA model"},
-    )
     config_name: Optional[str] = field(
         default=None,
         metadata={
@@ -141,21 +137,6 @@ class DataTrainingArguments:
     Arguments pertaining to what data to input to our model for training and eval
     """
 
-    recipe: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "Path to a SparseML sparsification recipe, see "
-                "https://github.com/neuralmagic/sparseml for more information"
-            )
-        },
-    )
-    recipe_args: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Recipe arguments to be overwritten",
-        },
-    )
     dataset_name: Optional[str] = field(
         default=None,
         metadata={
@@ -444,7 +425,7 @@ def main():
             "revision": model_args.model_revision,
             "use_auth_token": True if model_args.use_auth_token else None,
         },
-        teacher_name_or_path=model_args.distill_teacher,
+        teacher_name_or_path=training_args.distill_teacher,
         teacher_kwargs={
             "cache_dir": model_args.cache_dir,
             "use_auth_token": True if model_args.use_auth_token else None,
         },
@@ -770,8 +751,8 @@ def compute_metrics(p: EvalPrediction):
     trainer = QuestionAnsweringTrainer(
         model=model,
         model_state_path=model_args.model_name_or_path,
-        recipe=data_args.recipe,
-        recipe_args=data_args.recipe_args,
+        recipe=training_args.recipe,
+        recipe_args=training_args.recipe_args,
         metadata_args=metadata_args,
         teacher=teacher,
         args=training_args,
diff --git a/src/sparseml/transformers/sparsification/__init__.py b/src/sparseml/transformers/sparsification/__init__.py
index 61a91e00a04..b8735ea4c3d 100644
--- a/src/sparseml/transformers/sparsification/__init__.py
+++ b/src/sparseml/transformers/sparsification/__init__.py
@@ -21,3 +21,4 @@
 
 from .question_answering import *
 from .trainer import *
+from .training_args import *
diff --git a/src/sparseml/transformers/sparsification/question_answering.py b/src/sparseml/transformers/sparsification/question_answering.py
index 6785cb8b7f2..a2de2389758 100644
--- a/src/sparseml/transformers/sparsification/question_answering.py
+++ b/src/sparseml/transformers/sparsification/question_answering.py
@@ -21,20 +21,21 @@
 """
 
 import collections
-import inspect
 import json
 import logging
 import os
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import datasets
 import numpy as np
 from torch.nn import Module
 from tqdm.auto import tqdm
-from transformers import Trainer, is_torch_tpu_available
+from transformers import is_torch_tpu_available
 from transformers.trainer_utils import PredictionOutput
 
-from sparseml.transformers.sparsification.trainer import TrainerInterface
+from sparseml.transformers.sparsification.trainer import (
+    TrainerInterface,
+    TransformersTrainer,
+)
 
 
 if is_torch_tpu_available():
@@ -51,7 +52,7 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-class _QuestionAnsweringTrainer(Trainer):
+class _QuestionAnsweringTrainer(TransformersTrainer):
     """
     Trainer implementation for Question-Answering processing
     """
@@ -210,30 +211,6 @@ def __init__(
             **kwargs,
         )
 
-    def _remove_unused_columns(
-        self, dataset: "datasets.Dataset", description: Optional[str] = None
-    ):
-        if (
-            self._signature_columns is None
-            and self.teacher is not None
-            and self.teacher not in ("disable", "self")
-        ):
-            model_signature = inspect.signature(self.model.forward)
-            model_signature_columns = set(model_signature.parameters.keys())
-
-            teacher_signature = inspect.signature(self.teacher.forward)
-            teacher_signature_columns = set(teacher_signature.parameters.keys())
-
-            self._signature_columns = list(
-                model_signature_columns | teacher_signature_columns
-            )
-
-            # Labels may be named label or label_ids, the default data
-            # collator handles that.
-            self._signature_columns += ["label", "label_ids"]
-
-        return super()._remove_unused_columns(dataset, description)
-
 
 def postprocess_qa_predictions(
     examples,
diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index 8875d31ecce..2f32c457745 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -29,7 +29,7 @@
 from torch import distributed as dist
 from torch.nn import Module
 from torch.utils.data import RandomSampler
-from transformers import Trainer as TransformersTrainer
+from transformers import Trainer as HFTransformersTrainer
 from transformers import TrainerCallback, TrainerControl, TrainingArguments
 from transformers.file_utils import WEIGHTS_NAME
 from transformers.trainer_callback import TrainerState
@@ -51,6 +51,7 @@
     "TrainerInterface",
     "Trainer",
     "DisableHalfPrecisionCallback",
+    "TransformersTrainer",
 ]
 
 
@@ -830,39 +831,25 @@ def _generate_apply_manager_params(self, kwargs) -> Tuple[Optional[str], float]:
         return checkpoint, epoch
 
 
-class Trainer(TrainerInterface, TransformersTrainer):
+class TransformersTrainer(HFTransformersTrainer):
     """
-    Training implementation for running sparsification recipes with transformers flows.
-    :param model: the model to use with the trainer and apply sparsification to
-    :param model_state_path: the state path to the model,
-        used to load config and tokenizer settings
-    :param recipe: the recipe, if any, to apply to the modle and training
-        process
-    :param recipe_args: A json string, csv key=value string, or dictionary containing
-        arguments to override the root arguments within the recipe such as
-        learning rate or num epochs
-    :param teacher: teacher model for distillation. Set to 'self' to distill
-        from the loaded model or 'disable' to turn of distillation
-    :param kwargs: key word arguments passed to the parent class
+    A transformers trainer class with customized behaviors that can be shared
+    by all trainers inside SparseML
     """
-    def __init__(
-        self,
-        model: Module,
-        model_state_path: str,
-        recipe: Optional[str],
-        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
-        teacher: Optional[Union[Module, str]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            model=model,
-            model_state_path=model_state_path,
-            recipe=recipe,
-            recipe_args=recipe_args,
-            teacher=teacher,
-            **kwargs,
-        )
+    def _save_checkpoint(self, model, trial, metrics=None):
+        # Call into the checkpoint saving routine of HF Transformers, which
+        # saves the best metric if required
+        super()._save_checkpoint(model, trial, metrics=metrics)
+        if (
+            self.args.metric_for_best_model is None
+            or self.args.best_model_after_epoch is None
+        ):
+            return
+
+        if self.state.epoch <= self.args.best_model_after_epoch:
+            self.state.best_metric = None
+            self.state.best_model_checkpoint = None
 
     def _remove_unused_columns(
         self, dataset: "datasets.Dataset", description: Optional[str] = None
     ):
@@ -901,6 +888,41 @@ def _remove_unused_columns(
         return super()._remove_unused_columns(dataset, description)
 
 
+class Trainer(TrainerInterface, TransformersTrainer):
+    """
+    Training implementation for running sparsification recipes with transformers flows.
+    :param model: the model to use with the trainer and apply sparsification to
+    :param model_state_path: the state path to the model,
+        used to load config and tokenizer settings
+    :param recipe: the recipe, if any, to apply to the model and training
+        process
+    :param recipe_args: A json string, csv key=value string, or dictionary containing
+        arguments to override the root arguments within the recipe such as
+        learning rate or num epochs
+    :param teacher: teacher model for distillation. Set to 'self' to distill
+        from the loaded model or 'disable' to turn off distillation
+    :param kwargs: keyword arguments passed to the parent class
+    """
+
+    def __init__(
+        self,
+        model: Module,
+        model_state_path: str,
+        recipe: Optional[str],
+        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
+        teacher: Optional[Union[Module, str]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            model_state_path=model_state_path,
+            recipe=recipe,
+            recipe_args=recipe_args,
+            teacher=teacher,
+            **kwargs,
+        )
+
+
 class DisableHalfPrecisionCallback(TrainerCallback):
     """
     TrainerCallback for disabling FP16 training before QAT training begins
diff --git a/src/sparseml/transformers/sparsification/training_args.py b/src/sparseml/transformers/sparsification/training_args.py
new file mode 100644
index 00000000000..a1aa639ad87
--- /dev/null
+++ b/src/sparseml/transformers/sparsification/training_args.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from transformers import TrainingArguments as HFTrainingArgs
+
+
+__all__ = ["TrainingArguments"]
+
+
+@dataclass
+class TrainingArguments(HFTrainingArgs):
+    """
+    Training arguments specific to the SparseML Transformers workflow
+
+    :param best_model_after_epoch (`int`, *optional*, defaults to None):
+        The epoch after which the best model will be saved; used in conjunction
+        with `load_best_model_at_end` and `metric_for_best_model` training
+        arguments
+    """
+
+    distill_teacher: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Teacher model (a trained model used for distillation)",
+        },
+    )
+    best_model_after_epoch: Optional[int] = field(
+        default=None,
+        metadata={"help": "Epoch after which the best model will be saved."},
+    )
+    recipe: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to a SparseML sparsification recipe, see "
+                "https://github.com/neuralmagic/sparseml for more information"
+            ),
+        },
+    )
+    recipe_args: Optional[str] = field(
+        default=None,
+        metadata={"help": "Recipe arguments to be overwritten"},
+    )
diff --git a/src/sparseml/transformers/text_classification.py b/src/sparseml/transformers/text_classification.py
index 7fb553f6ce1..fc070a4f9fa 100644
--- a/src/sparseml/transformers/text_classification.py
+++ b/src/sparseml/transformers/text_classification.py
@@ -42,7 +42,6 @@
     EvalPrediction,
     HfArgumentParser,
     PretrainedConfig,
-    TrainingArguments,
    default_data_collator,
     set_seed,
 )
@@ -50,7 +49,7 @@
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 
-from sparseml.transformers.sparsification import Trainer
+from sparseml.transformers.sparsification import Trainer, TrainingArguments
 from sparseml.transformers.utils import SparseAutoModel, get_shared_tokenizer_src
 
 
@@ -94,19 +93,6 @@ class DataTrainingArguments:
     arguments to be able to specify them on the command line
     """
 
-    recipe: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "Path to a SparseML sparsification recipe, see "
-                "https://github.com/neuralmagic/sparseml for more information"
-            ),
-        },
-    )
-    recipe_args: Optional[str] = field(
-        default=None,
-        metadata={"help": "Recipe arguments to be overwritten"},
-    )
     task_name: Optional[str] = field(
         default=None,
         metadata={
@@ -254,12 +240,6 @@ class ModelArguments:
             )
         }
     )
-    distill_teacher: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Teacher model which must be a trained text classification model"
-        },
-    )
     config_name: Optional[str] = field(
         default=None,
         metadata={
@@ -480,7 +460,7 @@ def main():
             "revision": model_args.model_revision,
             "use_auth_token": True if model_args.use_auth_token else None,
         },
-        teacher_name_or_path=model_args.distill_teacher,
+        teacher_name_or_path=training_args.distill_teacher,
         teacher_kwargs={
             "cache_dir": model_args.cache_dir,
             "use_auth_token": True if model_args.use_auth_token else None,
         },
@@ -718,9 +698,9 @@ def compute_metrics(p: EvalPrediction):
     trainer = Trainer(
         model=model,
         model_state_path=model_args.model_name_or_path,
-        recipe=data_args.recipe,
+        recipe=training_args.recipe,
         metadata_args=metadata_args,
-        recipe_args=data_args.recipe_args,
+        recipe_args=training_args.recipe_args,
         teacher=teacher,
         args=training_args,
         data_args=data_args,
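
For context, a minimal sketch (separate from the patch itself) of how the relocated arguments are consumed once `recipe`, `recipe_args`, and `distill_teacher` live on the sparsification `TrainingArguments` dataclass. The argument names come from the dataclass added above; the script values below are hypothetical placeholders.

# Sketch only: parse the SparseML TrainingArguments introduced in
# src/sparseml/transformers/sparsification/training_args.py.
from transformers import HfArgumentParser

from sparseml.transformers.sparsification import TrainingArguments

cli_args = [
    "--output_dir", "./qa_output",           # required by the HF base TrainingArguments
    "--recipe", "recipe.yaml",               # previously data_args.recipe
    "--recipe_args", '{"num_epochs": 5}',    # previously data_args.recipe_args
    "--distill_teacher", "path/to/teacher",  # previously model_args.distill_teacher
    "--best_model_after_epoch", "2",         # new field read by TransformersTrainer._save_checkpoint
]

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(args=cli_args)

# The scripts now read these fields from training_args, e.g.
# QuestionAnsweringTrainer(..., recipe=training_args.recipe, args=training_args, ...)
print(training_args.recipe, training_args.distill_teacher)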