diff --git a/src/sparseml/transformers/__init__.py b/src/sparseml/transformers/__init__.py
index 64a21d7740b..3fca56e1949 100644
--- a/src/sparseml/transformers/__init__.py
+++ b/src/sparseml/transformers/__init__.py
@@ -34,7 +34,7 @@ _LOGGER = _logging.getLogger(__name__)
 
 _NM_TRANSFORMERS_TAR_TEMPLATE = (
     "https://github.com/neuralmagic/transformers/releases/download/"
-    "{version}/transformers-4.18.0.dev0-py3-none-any.whl"
+    "{version}/transformers-4.23.1-py3-none-any.whl"
 )
 
 _NM_TRANSFORMERS_NIGHTLY = _NM_TRANSFORMERS_TAR_TEMPLATE.format(version="nightly")
diff --git a/src/sparseml/transformers/sparsification/question_answering.py b/src/sparseml/transformers/sparsification/question_answering.py
index d2cd4778d2d..a681122b5d0 100644
--- a/src/sparseml/transformers/sparsification/question_answering.py
+++ b/src/sparseml/transformers/sparsification/question_answering.py
@@ -80,9 +80,9 @@ def evaluate(
         eval_examples = self.eval_examples if eval_examples is None else eval_examples
 
         # Always evaluate w/ fp32 to be closer to DeepSparse
-        use_amp = self.use_amp
+        use_cuda_amp = self.use_cuda_amp
         if not self.args.fp16_full_eval and not self.args.bf16_full_eval:
-            self.use_amp = False
+            self.use_cuda_amp = False
 
         # Temporarily disable metric computation, we will do it in the loop here.
         compute_metrics = self.compute_metrics
@@ -129,7 +129,7 @@ def evaluate(
             self.args, self.state, self.control, metrics
         )
 
-        self.use_amp = use_amp
+        self.use_cuda_amp = use_cuda_amp
 
         return metrics
 
diff --git a/src/sparseml/transformers/sparsification/trainer.py b/src/sparseml/transformers/sparsification/trainer.py
index 03433e9d78a..18aa8b47f36 100644
--- a/src/sparseml/transformers/sparsification/trainer.py
+++ b/src/sparseml/transformers/sparsification/trainer.py
@@ -638,11 +638,16 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
             return
 
         # change in keys due to architecture changes, reload statedict
-        load_state_dict = torch.load(
+        loaded_state_dict = torch.load(
             os.path.join(load_path, WEIGHTS_NAME), map_location="cpu"
         )
-        _, missing, unexpected, _, _ = self.model._load_state_dict_into_model(
-            self.model, load_state_dict, load_path, _fast_init=False
+        _, missing, unexpected, _, _ = self.model._load_pretrained_model(
+            model=self.model,
+            state_dict=loaded_state_dict,
+            loaded_keys=list(loaded_state_dict.keys()),
+            resolved_archive_file=[],
+            pretrained_model_name_or_path=load_path,
+            _fast_init=False,
         )
 
         if missing:
@@ -803,12 +808,12 @@ def evaluate(self, *args, **kwargs):
         applied = self.apply_manager(epoch=math.inf, checkpoint=None)
 
         # Always evaluate w/ fp32 to be closer to DeepSparse
-        use_amp = self.use_amp
+        use_cuda_amp = self.use_cuda_amp
         if not self.args.fp16_full_eval and not self.args.bf16_full_eval:
-            self.use_amp = False
+            self.use_cuda_amp = False
 
         output = super().evaluate(*args, **kwargs)
-        self.use_amp = use_amp
+        self.use_cuda_amp = use_cuda_amp
 
         if applied:
             self.finalize_manager()
@@ -901,7 +906,7 @@ def save_optimizer_and_scheduler(self, output_dir: Optional[str] = None):
                     os.path.join(output_dir, "scheduler.pt"),
                 )
             reissue_pt_warnings(caught_warnings)
-            if self.use_amp:
+            if self.use_cuda_amp:
                 torch.save(
                     self.scaler.state_dict(), os.path.join(output_dir, "scaler.pt")
                 )
@@ -1038,7 +1043,7 @@ def disable_amp(self, epoch: float):
         if not self.on_begin_called:
             # disable if training loops haven't started so we don't load
            # the empty scaler state dict and instead disable it from the start
-            self.trainer.use_amp = False
+            self.trainer.use_cuda_amp = False
 
         if hasattr(self.trainer, "scaler"):
             self.trainer.scaler._enabled = False
diff --git a/tests/integrations/transformers/args.py b/tests/integrations/transformers/args.py
index e9fd4e687ff..a7fa5f32aa4 100644
--- a/tests/integrations/transformers/args.py
+++ b/tests/integrations/transformers/args.py
@@ -513,13 +513,6 @@ class _TransformersTrainArgs(BaseModel):
         description="Used by the SageMaker launcher to send mp-specific args. "
         "Ignored in Trainer",
     )
-    modifier_log_frequency: float = Field(
-        default=0.1,
-        description=(
-            "How often to log SparseML modifier data, in number of epochs or fraction "
-            "of epochs"
-        ),
-    )
 
 
 class QuestionAnsweringArgs(_TransformersTrainArgs):
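
Note on the use_amp -> use_cuda_amp rename above: the new attribute name matches the one used by the newer transformers Trainer targeted by the 4.23.1 wheel bump. Below is a minimal, self-contained sketch of the save/clear/restore pattern both evaluate() overrides follow so metrics are computed in fp32 unless fp16_full_eval/bf16_full_eval is requested. It is illustrative only, not the SparseML implementation; TinyTrainer and its attributes are hypothetical stand-ins.

    # sketch.py -- illustrative only
    class TinyTrainer:
        """Stand-in trainer with an AMP flag akin to Trainer.use_cuda_amp."""

        def __init__(self, fp16_full_eval: bool = False, bf16_full_eval: bool = False):
            self.use_cuda_amp = True
            self.fp16_full_eval = fp16_full_eval
            self.bf16_full_eval = bf16_full_eval

        def _evaluation_loop(self) -> dict:
            # a real trainer would run the model here; just report the AMP state
            return {"amp_enabled_during_eval": self.use_cuda_amp}

        def evaluate(self) -> dict:
            saved_use_cuda_amp = self.use_cuda_amp
            if not self.fp16_full_eval and not self.bf16_full_eval:
                # evaluate in fp32 to stay closer to DeepSparse numerics
                self.use_cuda_amp = False
            try:
                return self._evaluation_loop()
            finally:
                # always restore the caller's AMP setting, even on error
                self.use_cuda_amp = saved_use_cuda_amp


    print(TinyTrainer().evaluate())  # {'amp_enabled_during_eval': False}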