diff --git a/src/transformers/sparse.py b/src/transformers/sparse.py
index 2c5006ae73e2..60f120e44cbb 100644
--- a/src/transformers/sparse.py
+++ b/src/transformers/sparse.py
@@ -120,6 +120,14 @@ def create_scheduler(self, num_training_steps: int):
             # default scheduler
             super().create_scheduler(num_training_steps)
 
+    def qat_active(self, epoch: int):
+        if not self.manager.quantization_modifiers:
+            return False
+
+        qat_start = min([mod.start_epoch for mod in self.manager.quantization_modifiers])
+
+        return qat_start < epoch + 1
+
     def save_model(self, output_dir: Optional[str] = None):
         """
         Save model during or after training. The sparsification recipe will also be saved.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index fd1a03930734..b299a96744b8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1215,6 +1215,10 @@ def train(
                     break
 
         for epoch in range(epochs_trained, num_train_epochs):
+            if self.use_amp and hasattr(self, "qat_active") and callable(self.qat_active) and self.qat_active(epoch):
+                logger.info("entering QAT phase, disabling FP16 training")
+                self.scaler._enabled = False
+
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
             elif isinstance(train_dataloader.dataset, IterableDatasetShard):
@@ -1732,7 +1736,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
         if self.use_amp:
-            with autocast():
+            with autocast(enabled=self.scaler.is_enabled()):
                 loss = self.compute_loss(model, inputs)
         else:
             loss = self.compute_loss(model, inputs)
@@ -2377,7 +2381,7 @@ def prediction_step(
                 else:
                     loss = None
                     if self.use_amp:
-                        with autocast():
+                        with autocast(enabled=self.scaler.is_enabled()):
                             outputs = model(**inputs)
                     else:
                         outputs = model(**inputs)
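
A minimal, self-contained sketch (not part of the patch) of how the new qat_active(epoch) gate behaves, assuming a recipe manager whose quantization_modifiers expose start_epoch, as in the diff above; the _FakeScaler and the hard-coded start_epoch=2.0 are hypothetical stand-ins used only to show when FP16 gets switched off.

# Hypothetical sketch: when does qat_active(epoch) flip FP16 off,
# assuming the earliest quantization modifier starts at epoch 2.0?
from types import SimpleNamespace


class _FakeScaler:
    # Stand-in for torch.cuda.amp.GradScaler, illustrating the enabled flag only.
    def __init__(self):
        self._enabled = True

    def is_enabled(self):
        return self._enabled


manager = SimpleNamespace(quantization_modifiers=[SimpleNamespace(start_epoch=2.0)])
scaler = _FakeScaler()


def qat_active(epoch: int) -> bool:
    # Mirrors the patched method: QAT counts as active once the earliest
    # quantization start epoch falls before the next epoch boundary (epoch + 1).
    if not manager.quantization_modifiers:
        return False
    qat_start = min(mod.start_epoch for mod in manager.quantization_modifiers)
    return qat_start < epoch + 1


for epoch in range(4):
    if scaler.is_enabled() and qat_active(epoch):
        print(f"epoch {epoch}: entering QAT phase, disabling FP16 training")
        scaler._enabled = False
    # autocast(enabled=scaler.is_enabled()) would run in full precision from here on
    print(f"epoch {epoch}: autocast enabled -> {scaler.is_enabled()}")

With these assumptions, epochs 0 and 1 keep FP16 on, and the gate trips at epoch 2, which is why the patch passes enabled=self.scaler.is_enabled() into autocast rather than unconditionally enabling it.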