diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 101610af555f..d5375e2d2ab3 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -14,7 +14,7 @@
 """
 Integration with Deepspeed
 """
-
+import copy
 import importlib.metadata as importlib_metadata
 import importlib.util
 import weakref
@@ -27,7 +27,6 @@
 if is_torch_available():
     import torch
 
-    from ..optimization import get_scheduler
 
 
 logger = logging.get_logger(__name__)
@@ -311,12 +310,15 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
     if isinstance(optimizer, DummyOptim):
 
         def _lr_scheduler_callable(optimizer):
-            return get_scheduler(
-                trainer.args.lr_scheduler_type,
-                optimizer=optimizer,
-                num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
-                num_training_steps=num_training_steps,
+            # create a shallow copy first, so later modifications do not affect original trainer
+            trainer_copy = copy.copy(trainer)
+            # at the time _lr_scheduler_callable is called, trainer.lr_scheduler has been set
+            # update it to None so that we can re-create a new scheduler
+            trainer_copy.lr_scheduler = None
+            lr_scheduler = trainer_copy.create_scheduler(
+                num_training_steps=num_training_steps, optimizer=optimizer
             )
+            return lr_scheduler
 
         lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
     else:
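
For context, a minimal sketch (using a hypothetical `TrainerLike` class, not actual transformers code) of why resetting `lr_scheduler` on a shallow copy is enough here: the trainer's `create_scheduler` only builds a new scheduler when `lr_scheduler` is `None`, and rebinding the attribute on the copy leaves the original trainer object untouched.

```python
# Hypothetical stand-in for Trainer; illustrates the shallow-copy pattern used above.
import copy


class TrainerLike:
    def __init__(self):
        self.lr_scheduler = "already-set scheduler"

    def create_scheduler(self, num_training_steps, optimizer):
        # mimics the Trainer behavior assumed by the patch:
        # only build a new scheduler when none is set yet
        if self.lr_scheduler is None:
            self.lr_scheduler = f"scheduler({optimizer}, steps={num_training_steps})"
        return self.lr_scheduler


trainer = TrainerLike()
trainer_copy = copy.copy(trainer)
trainer_copy.lr_scheduler = None  # force re-creation on the copy only

print(trainer_copy.create_scheduler(num_training_steps=1000, optimizer="real optimizer"))
print(trainer.lr_scheduler)  # still "already-set scheduler": the original is not mutated
```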