diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py index efeccb85c246..abdfd98e37ef 100644 --- a/src/transformers/integrations/deepspeed.py +++ b/src/transformers/integrations/deepspeed.py @@ -26,6 +26,8 @@ if is_torch_available(): import torch + from ..optimization import get_scheduler + logger = logging.get_logger(__name__) @@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps # 1. DS scheduler + DS optimizer: Yes # 2. HF scheduler + HF optimizer: Mostly* # 3. DS scheduler + HF optimizer: Mostly* - # 4. HF scheduler + DS optimizer: No + # 4. HF scheduler + DS optimizer: Yes # # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB) @@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps lr_scheduler = DummyScheduler(optimizer) else: if isinstance(optimizer, DummyOptim): - raise ValueError( - "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. " - "Please configure a scheduler in the DeepSpeed config." - ) - lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + + def _lr_scheduler_callable(optimizer): + return get_scheduler( + trainer.args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + + lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable) + else: + lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) return optimizer, lr_scheduler diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 45e858feca18..20f82596b92f 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -60,7 +60,7 @@ from .debug_utils import DebugOption, DebugUnderflowOverflow from .dependency_versions_check import dep_version_check from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend -from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint +from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available from .modelcard import TrainingSummary from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES @@ -212,6 +212,9 @@ save_fsdp_optimizer, ) + if is_deepspeed_available(): + from accelerate.utils import DeepSpeedSchedulerWrapper + if TYPE_CHECKING: import optuna @@ -2362,7 +2365,14 @@ def _save_checkpoint(self, model, trial, metrics=None): torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) # Save SCHEDULER & SCALER - if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available(): + is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance( + self.lr_scheduler, DeepSpeedSchedulerWrapper + ) + if ( + self.args.should_save + and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler) + and not is_torch_tpu_available() + ): with warnings.catch_warnings(record=True) as caught_warnings: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) reissue_pt_warnings(caught_warnings) @@ -2428,6 +2438,10 @@ def _load_optimizer_and_scheduler(self, checkpoint): if self.is_deepspeed_enabled: # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init + if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper): + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + reissue_pt_warnings(caught_warnings) return checkpoint_file_exists = ( diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 2fa1caf0b5ca..3f8ca1033213 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -136,6 +136,14 @@ def get_launcher(distributed=False): FP16 = "fp16" BF16 = "bf16" +HF_OPTIM = "hf_optim" +HF_SCHEDULER = "hf_scheduler" +DS_OPTIM = "ds_optim" +DS_SCHEDULER = "ds_scheduler" + +optims = [HF_OPTIM, DS_OPTIM] +schedulers = [HF_SCHEDULER, DS_SCHEDULER] + stages = [ZERO2, ZERO3] if is_torch_bf16_gpu_available(): dtypes = [FP16, BF16] @@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param): # Cartesian-product of zero stages with models to test params = list(itertools.product(stages, dtypes)) +params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers)) + @require_deepspeed @require_torch_gpu @@ -640,10 +650,16 @@ def test_can_resume_training_errors(self, stage, dtype): "Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}" ) - @parameterized.expand(params, name_func=parameterized_custom_name_func) - def test_can_resume_training_normal(self, stage, dtype): + @parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func) + def test_can_resume_training_normal(self, stage, dtype, optim, scheduler): # adapted from TrainerIntegrationTest.test_can_resume_training # test normal resume for each stage separately, error-handling is tested in a different test + + # ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and + # also has same losses for few steps but then slowly diverges. Need to figure it out. + if optim == HF_OPTIM and scheduler == HF_SCHEDULER: + return + output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False) ds_config_dict = self.get_config_dict(stage) if dtype == FP16: @@ -652,6 +668,12 @@ def test_can_resume_training_normal(self, stage, dtype): if stage == ZERO3: ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True + if optim == HF_OPTIM: + del ds_config_dict["optimizer"] + + if scheduler == HF_SCHEDULER: + del ds_config_dict["scheduler"] + kwargs = { "output_dir": output_dir, "train_len": 128,