8 changes: 0 additions & 8 deletions examples/pytorch/question-answering/run_qa.py
@@ -65,12 +65,6 @@ class ModelArguments:
     distill_teacher: Optional[str] = field(
         default=None, metadata={"help": "Teacher model which needs to be a trained QA model"}
     )
-    distill_temperature: Optional[float] = field(
-        default=2.0, metadata={"help": "Temperature applied to teacher softmax for distillation."}
-    )
-    distill_hardness: Optional[float] = field(
-        default=1.0, metadata={"help": "Proportion of loss coming from teacher model."}
-    )
     config_name: Optional[str] = field(
         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
     )
@@ -593,8 +587,6 @@ def compute_metrics(p: EvalPrediction):
         model_args.model_name_or_path,
         [existing_recipe, new_recipe],
         teacher=teacher_model,
-        distill_hardness=model_args.distill_hardness,
-        distill_temperature=model_args.distill_temperature,
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
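With the --distill_temperature and --distill_hardness flags removed, those hyperparameters are expected to live in the SparseML recipe instead. A minimal sketch of what such a recipe could look like, assuming SparseML's DistillationModifier and its hardness/temperature/distill_output_keys fields; the recipe text and from_yaml-on-a-string usage are illustrative assumptions, not part of this PR:

# Illustrative only: distillation hyperparameters move from CLI flags into
# the SparseML recipe. Field names assume SparseML's DistillationModifier;
# check your SparseML version for the exact schema.
from sparseml.pytorch.optim import ScheduledModifierManager

recipe = """
modifiers:
    - !DistillationModifier
        start_epoch: 0.0
        hardness: 0.5       # weight on the teacher (KD) loss; was --distill_hardness
        temperature: 2.0    # softmax temperature; was --distill_temperature
        distill_output_keys: [start_logits, end_logits]
"""

# from_yaml takes a recipe file path; raw YAML strings like the one above
# are also accepted in the SparseML versions this author has seen (assumption).
manager = ScheduledModifierManager.from_yaml(recipe)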
60 changes: 18 additions & 42 deletions examples/pytorch/question-answering/sparseml_utils.py
@@ -1,8 +1,6 @@
 from typing import Any
 
 import numpy
-import torch
-import torch.nn.functional as F
 
 from sparseml.pytorch.utils import ModuleExporter
 from trainer_qa import QuestionAnsweringTrainer
@@ -28,46 +26,24 @@ def compute_loss(self, model, inputs, return_outputs=False):
         if not self.recipes or self.teacher is None:
             return super().compute_loss(model, inputs, return_outputs=return_outputs)
 
-        outputs = model(**inputs)
-        if self.teacher is None:
-            loss = outputs["loss"]
-        else:
-            input_device = inputs["input_ids"].device
-            self.teacher = self.teacher.to(input_device)
-            start_logits_student = outputs["start_logits"]
-            end_logits_student = outputs["end_logits"]
-            start_logits_label = inputs["start_positions"]
-            end_logits_label = inputs["end_positions"]
-            with torch.no_grad():
-                teacher_output = self.teacher(
-                    input_ids=inputs["input_ids"],
-                    token_type_ids=inputs["token_type_ids"],
-                    attention_mask=inputs["attention_mask"],
-                )
-            start_logits_teacher = teacher_output["start_logits"]
-            end_logits_teacher = teacher_output["end_logits"]
-            loss_start = (
-                F.kl_div(
-                    input=F.log_softmax(start_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(start_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            loss_end = (
-                F.kl_div(
-                    input=F.log_softmax(end_logits_student / self.distill_temperature, dim=-1),
-                    target=F.softmax(end_logits_teacher / self.distill_temperature, dim=-1),
-                    reduction="batchmean",
-                )
-                * (self.distill_temperature ** 2)
-            )
-            teacher_loss = (loss_start + loss_end) / 2.0
-            loss_start = self.criterion(start_logits_student, start_logits_label)
-            loss_end = self.criterion(end_logits_student, end_logits_label)
-            label_loss = (loss_start + loss_end) / 2.0
-            loss = ((1 - self.distill_hardness) * label_loss) + (self.distill_hardness * teacher_loss)
-        return (loss, outputs) if return_outputs else loss
+        student_outputs = model(**inputs)
+        loss = student_outputs["loss"]
+
+        teacher_input_keys = ["input_ids", "token_type_ids", "attention_mask"]
+        teacher_inputs = {k: inputs[k] for k in teacher_input_keys}
+
+        steps_in_epoch = -1  # Unused
+        loss = self.manager.loss_update(
+            loss,
+            model,
+            self.optimizer,
+            self.state.epoch,
+            steps_in_epoch,
+            global_step=self.state.global_step,
+            student_outputs=student_outputs,
+            teacher_inputs=teacher_inputs,
+        )
+        return (loss, student_outputs) if return_outputs else loss
 
 
 class QuestionAnsweringModuleExporter(ModuleExporter):
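For reference, the recipe-driven manager.loss_update replaces the manual computation deleted above. A standalone sketch of that temperature-scaled KL term, reconstructed from the removed lines (not SparseML's actual internals):

# Sketch of the distillation term the deleted code computed per logit head;
# manager.loss_update now applies the recipe's equivalent of this and blends
# it with the label loss as (1 - hardness) * label_loss + hardness * kd_loss.
import torch
import torch.nn.functional as F

def kd_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 2.0) -> torch.Tensor:
    """Temperature-scaled KL divergence between student and teacher logits."""
    return F.kl_div(
        input=F.log_softmax(student_logits / temperature, dim=-1),
        target=F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)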
10 changes: 4 additions & 6 deletions src/transformers/sparse.py
@@ -26,15 +26,13 @@ class SparseMLTrainer(Trainer):
     :param args, kwargs: arguments passed into parent class
     """
 
-    def __init__(
-        self, model_name_or_path, recipes, teacher=None, distill_hardness=0.5, distill_temperature=2.0, *args, **kwargs
-    ):
+    def __init__(self, model_name_or_path, recipes, teacher=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_name_or_path = str(model_name_or_path)
         self.recipes = [recipe for recipe in recipes if recipe]
         self.teacher = teacher
-        self.distill_hardness = distill_hardness
-        self.distill_temperature = distill_temperature
+        if self.teacher is not None:
+            self.teacher.eval()
         self.criterion = torch.nn.CrossEntropyLoss()
 
         manager = None
@@ -57,7 +55,7 @@ def apply_recipes(self, epoch=0.0):
"""
if self.manager is not None:
org_state_dict = self.model.state_dict()
self.manager.initialize(self.model, epoch=epoch, loggers=self.loggers)
self.manager.initialize(self.model, epoch=epoch, distillation_teacher=self.teacher, loggers=self.loggers)
new_state_dict = self.model.state_dict()
new_params = [p for p in new_state_dict.keys() if p not in org_state_dict]

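Putting the pieces together, a minimal usage sketch of the slimmed-down trainer. The constructor arguments mirror the run_qa.py hunk above; the explicit apply_recipes() call and the surrounding variable setup are assumptions not shown in this diff:

# Assumed wiring: the recipes now carry the distillation config, so the
# trainer takes only the teacher model itself (no hardness/temperature args).
trainer = SparseMLTrainer(
    model_name_or_path,
    [existing_recipe, new_recipe],
    teacher=teacher_model,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.apply_recipes()  # manager.initialize(..., distillation_teacher=self.teacher, ...)
trainer.train()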