src/transformers/integrations/deepspeed.py (15 additions, 6 deletions)
@@ -26,6 +26,8 @@
 if is_torch_available():
     import torch
 
+    from ..optimization import get_scheduler
+
 logger = logging.get_logger(__name__)
 
 
@@ -274,7 +276,7 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
     # 1. DS scheduler + DS optimizer: Yes
     # 2. HF scheduler + HF optimizer: Mostly*
     # 3. DS scheduler + HF optimizer: Mostly*
-    # 4. HF scheduler + DS optimizer: No
+    # 4. HF scheduler + DS optimizer: Yes
     #
     # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
 
@@ -304,11 +306,18 @@ def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps
         lr_scheduler = DummyScheduler(optimizer)
     else:
         if isinstance(optimizer, DummyOptim):
-            raise ValueError(
-                "Found `optimizer` configured in the DeepSpeed config, but no `scheduler`. "
-                "Please configure a scheduler in the DeepSpeed config."
-            )
-        lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
+            def _lr_scheduler_callable(optimizer):
+                return get_scheduler(
+                    trainer.args.lr_scheduler_type,
+                    optimizer=optimizer,
+                    num_warmup_steps=trainer.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                )
+
+            lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
+        else:
+            lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
 
     return optimizer, lr_scheduler
 
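To make the new case 4 (HF scheduler + DS optimizer) concrete, here is a minimal sketch of how a user would hit this code path. The config values (`AdamW`, the `"auto"` placeholders, `cosine`, `warmup_steps=100`) are illustrative assumptions and not taken from this PR; the relevant point is a DeepSpeed config that defines an `optimizer` but no `scheduler`, so the Trainer now builds the scheduler selected by `lr_scheduler_type` through `_lr_scheduler_callable` instead of raising.

```python
# Illustrative sketch: DeepSpeed supplies the optimizer, HF supplies the scheduler.
from transformers import TrainingArguments

ds_config = {
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},  # DS optimizer
    # no "scheduler" block -> case 4: HF scheduler + DS optimizer
    "train_micro_batch_size_per_gpu": "auto",
}

args = TrainingArguments(
    output_dir="out",
    deepspeed=ds_config,         # dict (or path to a json) handed to the DeepSpeed integration
    lr_scheduler_type="cosine",  # HF scheduler, created lazily via the callable above
    warmup_steps=100,
)
# Trainer(model=..., args=args, train_dataset=...).train()
```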
src/transformers/trainer.py (16 additions, 2 deletions)
@@ -60,7 +60,7 @@
 from .debug_utils import DebugOption, DebugUnderflowOverflow
 from .dependency_versions_check import dep_version_check
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
-from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint
+from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
@@ -212,6 +212,9 @@
             save_fsdp_optimizer,
         )
 
+    if is_deepspeed_available():
+        from accelerate.utils import DeepSpeedSchedulerWrapper
+
 
 if TYPE_CHECKING:
     import optuna
@@ -2362,7 +2365,14 @@ def _save_checkpoint(self, model, trial, metrics=None):
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
 
         # Save SCHEDULER & SCALER
-        if self.args.should_save and not self.is_deepspeed_enabled and not is_torch_tpu_available():
+        is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
+            self.lr_scheduler, DeepSpeedSchedulerWrapper
+        )
+        if (
+            self.args.should_save
+            and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
+            and not is_torch_tpu_available()
+        ):
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
@@ -2428,6 +2438,10 @@ def _load_optimizer_and_scheduler(self, checkpoint):
 
         if self.is_deepspeed_enabled:
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
+            if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+                reissue_pt_warnings(caught_warnings)
             return
 
         checkpoint_file_exists = (
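As a reading aid for the two trainer.py hunks above, here is a tiny standalone sketch of the ownership rule they implement, ignoring the `should_save` and TPU guards. The helper name and the returned labels are made up for illustration; only the branching mirrors the new checks.

```python
# Hypothetical helper mirroring the new guard: who persists and restores the LR scheduler
# state in a checkpoint, the Trainer or the DeepSpeed engine?
def scheduler_saved_by(is_deepspeed_enabled: bool, is_ds_scheduler_wrapper: bool) -> str:
    if is_deepspeed_enabled and is_ds_scheduler_wrapper:
        # native DeepSpeed scheduler: the engine checkpoints (and reloads) it itself
        return "deepspeed engine"
    # no DeepSpeed, or an HF/client-side scheduler under DeepSpeed: the Trainer writes scheduler.pt
    return "trainer"

for ds, wrapped in [(False, False), (True, True), (True, False)]:
    print(f"deepspeed={ds}, DeepSpeedSchedulerWrapper={wrapped} -> {scheduler_saved_by(ds, wrapped)}")
```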
tests/deepspeed/test_deepspeed.py (24 additions, 2 deletions)
@@ -136,6 +136,14 @@ def get_launcher(distributed=False):
 FP16 = "fp16"
 BF16 = "bf16"
 
+HF_OPTIM = "hf_optim"
+HF_SCHEDULER = "hf_scheduler"
+DS_OPTIM = "ds_optim"
+DS_SCHEDULER = "ds_scheduler"
+
+optims = [HF_OPTIM, DS_OPTIM]
+schedulers = [HF_SCHEDULER, DS_SCHEDULER]
+
 stages = [ZERO2, ZERO3]
 if is_torch_bf16_gpu_available():
     dtypes = [FP16, BF16]
@@ -153,6 +161,8 @@ def parameterized_custom_name_func(func, param_num, param):
 # Cartesian-product of zero stages with models to test
 params = list(itertools.product(stages, dtypes))
 
+params_with_optims_and_schedulers = list(itertools.product(stages, dtypes, optims, schedulers))
+
 
 @require_deepspeed
 @require_torch_gpu
@@ -640,10 +650,16 @@ def test_can_resume_training_errors(self, stage, dtype):
             "Can't find a valid checkpoint at" in str(context.exception), f"got exception: {context.exception}"
         )
 
-    @parameterized.expand(params, name_func=parameterized_custom_name_func)
-    def test_can_resume_training_normal(self, stage, dtype):
+    @parameterized.expand(params_with_optims_and_schedulers, name_func=parameterized_custom_name_func)
+    def test_can_resume_training_normal(self, stage, dtype, optim, scheduler):
         # adapted from TrainerIntegrationTest.test_can_resume_training
         # test normal resume for each stage separately, error-handling is tested in a different test
+
+        # ToDo: Currently, hf_optim + hf_scheduler resumes with the correct states and
+        # also has same losses for few steps but then slowly diverges. Need to figure it out.
+        if optim == HF_OPTIM and scheduler == HF_SCHEDULER:
+            return
+
         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
         ds_config_dict = self.get_config_dict(stage)
         if dtype == FP16:
@@ -652,6 +668,12 @@ def test_can_resume_training_normal(self, stage, dtype):
         if stage == ZERO3:
             ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True
 
+        if optim == HF_OPTIM:
+            del ds_config_dict["optimizer"]
+
+        if scheduler == HF_SCHEDULER:
+            del ds_config_dict["scheduler"]
+
         kwargs = {
             "output_dir": output_dir,
             "train_len": 128,
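Finally, a small self-contained sketch of what the new parameterization expands to. It assumes `parameterized_custom_name_func` joins the parameter values with underscores (as the existing stage/dtype test names suggest) and hard-codes the dtype list; on machines without bf16 support only `fp16` is generated.

```python
# Illustrative only: the resume test now runs over 16 (stage, dtype, optim, scheduler) combinations.
import itertools

stages = ["zero2", "zero3"]
dtypes = ["fp16", "bf16"]  # just ["fp16"] when bf16 is unavailable
optims = ["hf_optim", "ds_optim"]
schedulers = ["hf_scheduler", "ds_scheduler"]

for combo in itertools.product(stages, dtypes, optims, schedulers):
    print("test_can_resume_training_normal_" + "_".join(combo))
# the hf_optim + hf_scheduler combinations currently return early (see the ToDo above)
```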