Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 4 additions & 22 deletions src/sparseml/transformers/masked_language_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@
AutoTokenizer,
DataCollatorForLanguageModeling,
HfArgumentParser,
TrainingArguments,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

from sparseml.transformers.sparsification import Trainer
from sparseml.transformers.sparsification import Trainer, TrainingArguments
from sparseml.transformers.utils import SparseAutoModel, get_shared_tokenizer_src


Expand Down Expand Up @@ -108,10 +107,6 @@ class ModelArguments:
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
},
)
distill_teacher: Optional[str] = field(
default=None,
metadata={"help": "Teacher model which needs to be a trained QA model"},
)
config_name: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -164,19 +159,6 @@ class DataTrainingArguments:
training and eval
"""

recipe: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to a SparseML sparsification recipe, see "
"https://github.com/neuralmagic/sparseml for more information"
),
},
)
recipe_args: Optional[str] = field(
default=None,
metadata={"help": "Recipe arguments to be overwritten"},
)
dataset_name: Optional[str] = field(
default=None,
metadata={"help": "The name of the dataset to use (via the datasets library)"},
Expand Down Expand Up @@ -490,7 +472,7 @@ def main(**kwargs):
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
},
teacher_name_or_path=model_args.distill_teacher,
teacher_name_or_path=training_args.distill_teacher,
teacher_kwargs={
"cache_dir": model_args.cache_dir,
"use_auth_token": True if model_args.use_auth_token else None,
Expand Down Expand Up @@ -682,9 +664,9 @@ def compute_metrics(eval_preds):
trainer = Trainer(
model=model,
model_state_path=model_args.model_name_or_path,
recipe=data_args.recipe,
recipe=training_args.recipe,
metadata_args=metadata_args,
recipe_args=data_args.recipe_args,
recipe_args=training_args.recipe_args,
teacher=teacher,
args=training_args,
data_args=data_args,
Expand Down
27 changes: 4 additions & 23 deletions src/sparseml/transformers/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
EvalPrediction,
HfArgumentParser,
PreTrainedTokenizerFast,
TrainingArguments,
default_data_collator,
set_seed,
)
Expand All @@ -50,6 +49,7 @@

from sparseml.transformers.sparsification import (
QuestionAnsweringTrainer,
TrainingArguments,
postprocess_qa_predictions,
)
from sparseml.transformers.utils import SparseAutoModel, get_shared_tokenizer_src
Expand Down Expand Up @@ -90,10 +90,6 @@ class ModelArguments:
)
}
)
distill_teacher: Optional[str] = field(
default=None,
metadata={"help": "Teacher model which needs to be a trained QA model"},
)
config_name: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -141,21 +137,6 @@ class DataTrainingArguments:
Arguments pertaining to what data to input to our model for training and eval
"""

recipe: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to a SparseML sparsification recipe, see "
"https://github.com/neuralmagic/sparseml for more information"
)
},
)
recipe_args: Optional[str] = field(
default=None,
metadata={
"help": "Recipe arguments to be overwritten",
},
)
dataset_name: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -445,7 +426,7 @@ def main(**kwargs):
"revision": model_args.model_revision,
"use_auth_token": True if model_args.use_auth_token else None,
},
teacher_name_or_path=model_args.distill_teacher,
teacher_name_or_path=training_args.distill_teacher,
teacher_kwargs={
"cache_dir": model_args.cache_dir,
"use_auth_token": True if model_args.use_auth_token else None,
Expand Down Expand Up @@ -772,8 +753,8 @@ def compute_metrics(p: EvalPrediction):
trainer = QuestionAnsweringTrainer(
model=model,
model_state_path=model_args.model_name_or_path,
recipe=data_args.recipe,
recipe_args=data_args.recipe_args,
recipe=training_args.recipe,
recipe_args=training_args.recipe_args,
metadata_args=metadata_args,
teacher=teacher,
args=training_args,
Expand Down
1 change: 1 addition & 0 deletions src/sparseml/transformers/sparsification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@

from .question_answering import *
from .trainer import *
from .training_args import *
35 changes: 6 additions & 29 deletions src/sparseml/transformers/sparsification/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@
"""

import collections
import inspect
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
import numpy as np
from torch.nn import Module
from tqdm.auto import tqdm
from transformers import Trainer, is_torch_tpu_available
from transformers import is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput

from sparseml.transformers.sparsification.trainer import TrainerInterface
from sparseml.transformers.sparsification.trainer import (
TrainerInterface,
TransformersTrainer,
)


if is_torch_tpu_available():
Expand All @@ -51,7 +52,7 @@
_LOGGER = logging.getLogger(__name__)


class _QuestionAnsweringTrainer(Trainer):
class _QuestionAnsweringTrainer(TransformersTrainer):
"""
Trainer implementation for Question-Answering processing
"""
Expand Down Expand Up @@ -224,30 +225,6 @@ def __init__(
**kwargs,
)

def _remove_unused_columns(
    self, dataset: "datasets.Dataset", description: Optional[str] = None
):
    """
    Widen the cached signature columns to include the teacher's forward
    arguments before delegating column removal to the parent class.

    The cache is only populated here when it has not been set yet and a
    teacher object is in use (i.e. the teacher is neither ``None`` nor one of
    the string settings ``"disable"`` / ``"self"``).

    :param dataset: the dataset passed through to the parent implementation
    :param description: optional description passed through to the parent
        implementation
    :return: whatever the parent ``_remove_unused_columns`` returns
    """
    uses_teacher_module = (
        self._signature_columns is None
        and self.teacher is not None
        and self.teacher not in ("disable", "self")
    )

    if uses_teacher_module:
        student_params = set(inspect.signature(self.model.forward).parameters)
        teacher_params = set(inspect.signature(self.teacher.forward).parameters)

        # Keep every column accepted by either forward(); the label columns
        # may be named label or label_ids, the default data collator
        # handles that.
        self._signature_columns = [
            *(student_params | teacher_params),
            "label",
            "label_ids",
        ]

    return super()._remove_unused_columns(dataset, description)


def postprocess_qa_predictions(
examples,
Expand Down
84 changes: 53 additions & 31 deletions src/sparseml/transformers/sparsification/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import torch
from torch import distributed as dist
from torch.nn import Module
from transformers import Trainer as TransformersTrainer
from transformers import Trainer as HFTransformersTrainer
from transformers import TrainerCallback, TrainerControl, TrainingArguments
from transformers.file_utils import WEIGHTS_NAME
from transformers.integrations import TensorBoardCallback
Expand All @@ -51,6 +51,7 @@
"TrainerInterface",
"Trainer",
"DisableHalfPrecisionCallback",
"TransformersTrainer",
]


Expand Down Expand Up @@ -855,39 +856,25 @@ def _generate_apply_manager_params(self, kwargs) -> Tuple[Optional[str], float]:
return checkpoint, epoch


class Trainer(TrainerInterface, TransformersTrainer):
class TransformersTrainer(HFTransformersTrainer):
"""
Training implementation for running sparsification recipes with transformers flows.
:param model: the model to use with the trainer and apply sparsification to
:param model_state_path: the state path to the model,
used to load config and tokenizer settings
:param recipe: the recipe, if any, to apply to the model and training
process
:param recipe_args: A json string, csv key=value string, or dictionary containing
arguments to override the root arguments within the recipe such as
learning rate or num epochs
:param teacher: teacher model for distillation. Set to 'self' to distill
from the loaded model or 'disable' to turn off distillation
:param kwargs: key word arguments passed to the parent class
A transformers trainer class with custom behavior that can be shared
by all trainers inside SparseML
"""

def __init__(
self,
model: Module,
model_state_path: str,
recipe: Optional[str],
recipe_args: Optional[Union[Dict[str, Any], str]] = None,
teacher: Optional[Union[Module, str]] = None,
**kwargs,
):
super().__init__(
model=model,
model_state_path=model_state_path,
recipe=recipe,
recipe_args=recipe_args,
teacher=teacher,
**kwargs,
)
def _save_checkpoint(self, model, trial, metrics=None):
    """
    Save a checkpoint through the parent trainer, then discard the tracked
    best model while training is still at or before ``best_model_after_epoch``.

    The reset only happens when both ``args.metric_for_best_model`` and
    ``args.best_model_after_epoch`` are set; otherwise the parent's
    bookkeeping is left untouched.

    :param model: forwarded unchanged to the parent ``_save_checkpoint``
    :param trial: forwarded unchanged to the parent ``_save_checkpoint``
    :param metrics: forwarded unchanged to the parent ``_save_checkpoint``
    """
    # The parent implementation runs first so that its own best-metric
    # handling has already been applied before it is (possibly) cleared
    super()._save_checkpoint(model, trial, metrics=metrics)

    best_model_gating_enabled = (
        self.args.metric_for_best_model is not None
        and self.args.best_model_after_epoch is not None
    )

    if (
        best_model_gating_enabled
        and self.state.epoch <= self.args.best_model_after_epoch
    ):
        # Too early to count: forget any best metric / checkpoint recorded
        self.state.best_metric = None
        self.state.best_model_checkpoint = None

def _remove_unused_columns(
self, dataset: "datasets.Dataset", description: Optional[str] = None
Expand Down Expand Up @@ -926,6 +913,41 @@ def _remove_unused_columns(
return super()._remove_unused_columns(dataset, description)


class Trainer(TrainerInterface, TransformersTrainer):
    """
    Training implementation for running sparsification recipes with transformers flows.

    :param model: the model to use with the trainer and apply sparsification to
    :param model_state_path: the state path to the model,
        used to load config and tokenizer settings
    :param recipe: the recipe, if any, to apply to the model and training
        process
    :param recipe_args: A json string, csv key=value string, or dictionary containing
        arguments to override the root arguments within the recipe such as
        learning rate or num epochs
    :param teacher: teacher model for distillation. Set to 'self' to distill
        from the loaded model or 'disable' to turn off distillation
    :param kwargs: key word arguments passed to the parent class
    """

    def __init__(
        self,
        model: Module,
        model_state_path: str,
        recipe: Optional[str],
        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
        teacher: Optional[Union[Module, str]] = None,
        **kwargs,
    ):
        # Every argument is forwarded by keyword to the next __init__ in the
        # MRO (TrainerInterface is listed first among the base classes)
        super().__init__(
            model=model,
            model_state_path=model_state_path,
            recipe=recipe,
            recipe_args=recipe_args,
            teacher=teacher,
            **kwargs,
        )


class DisableHalfPrecisionCallback(TrainerCallback):
"""
TrainerCallback for disabling FP16 training before QAT training begins
Expand Down
57 changes: 57 additions & 0 deletions src/sparseml/transformers/sparsification/training_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from transformers import TrainingArguments as HFTrainingArgs


__all__ = ["TrainingArguments"]


@dataclass
class TrainingArguments(HFTrainingArgs):
    """
    Training arguments specific to SparseML Transformers workflow

    :param distill_teacher (`str`, *optional*, defaults to None):
        Name or path of the teacher model to use for distillation
    :param best_model_after_epoch (`int`, *optional*, defaults to None):
        The epoch after which best model will be saved; used in conjunction
        with `load_best_model_at_end` and `metric_for_best_model` training
        arguments
    :param recipe (`str`, *optional*, defaults to None):
        Path to a SparseML sparsification recipe
    :param recipe_args (`str`, *optional*, defaults to None):
        Recipe arguments to be overwritten
    """

    distill_teacher: Optional[str] = field(
        default=None,
        metadata={
            # Kept task-neutral: this argument is shared by every
            # transformers training script, not only text classification
            "help": "Name or path of a trained teacher model to use for distillation",
        },
    )
    # Optional because the default is None, which leaves the feature disabled
    best_model_after_epoch: Optional[int] = field(
        default=None,
        metadata={"help": "Epoch after which best model will be saved."},
    )
    recipe: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to a SparseML sparsification recipe, see "
                "https://github.com/neuralmagic/sparseml for more information"
            ),
        },
    )
    recipe_args: Optional[str] = field(
        default=None,
        metadata={"help": "Recipe arguments to be overwritten"},
    )
Loading