src/transformers/integrations/deepspeed.py (16 changes: 9 additions & 7 deletions)
@@ -14,7 +14,7 @@
 """
 Integration with Deepspeed
 """
 
+import copy
 import importlib.metadata as importlib_metadata
 import importlib.util
 import weakref
@@ -27,7 +27,6 @@
 if is_torch_available():
     import torch
 
-    from ..optimization import get_scheduler
 
 logger = logging.get_logger(__name__)
 
@@ -311,12 +310,15 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
         if isinstance(optimizer, DummyOptim):
 
             def _lr_scheduler_callable(optimizer):
-                return get_scheduler(
-                    trainer.args.lr_scheduler_type,
-                    optimizer=optimizer,
-                    num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
-                    num_training_steps=num_training_steps,
+                # create a shallow copy first, so later modifications do not affect original trainer
+                trainer_copy = copy.copy(trainer)
+                # at the time _lr_scheduler_callable is called, trainer.lr_scheduler has been set
+                # update it to None so that we can re-create a new scheduler
+                trainer_copy.lr_scheduler = None
+                lr_scheduler = trainer_copy.create_scheduler(
+                    num_training_steps=num_training_steps, optimizer=optimizer
                 )
+                return lr_scheduler
 
             lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
         else:
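For readers skimming the diff, the key is that `Trainer.create_scheduler` is lazy: it builds a scheduler only when `self.lr_scheduler` is None and otherwise returns the cached one, which is what the diff's own comment points at. The callable therefore shallow-copies the trainer and clears the cache on the copy, so a fresh scheduler gets built against the DeepSpeed-wrapped optimizer without mutating the real trainer. A minimal sketch of the pattern follows; `TrainerLike` is a hypothetical stand-in that mimics only the caching behavior, not the transformers class:

    import copy


    class TrainerLike:
        """Hypothetical stand-in for Trainer's scheduler caching."""

        def __init__(self):
            # by the time _lr_scheduler_callable runs, this is already set
            self.lr_scheduler = "existing scheduler"

        def create_scheduler(self, num_training_steps, optimizer=None):
            # like Trainer.create_scheduler: build only if nothing is cached
            if self.lr_scheduler is None:
                self.lr_scheduler = f"scheduler({num_training_steps} steps, {optimizer})"
            return self.lr_scheduler


    trainer = TrainerLike()

    # the pattern from the diff: shallow-copy, clear the cache on the copy only,
    # then re-create against the (DeepSpeed-wrapped) optimizer
    trainer_copy = copy.copy(trainer)
    trainer_copy.lr_scheduler = None
    fresh = trainer_copy.create_scheduler(num_training_steps=1000, optimizer="ds optimizer")

    print(fresh)                 # scheduler(1000 steps, ds optimizer)
    print(trainer.lr_scheduler)  # existing scheduler; the original is untouched

A deep copy would be needlessly expensive here, since a real trainer references the model and data loaders; the shallow copy suffices because the only thing rebound on the copy is the `lr_scheduler` attribute.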