4 changes: 2 additions & 2 deletions setup.py
@@ -27,7 +27,7 @@

3. Unpin specific versions from setup.py that use a git install.

4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
message: "Release: <VERSION>" and push.

5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
@@ -103,7 +103,7 @@
"cookiecutter==1.7.3",
"dataclasses",
"datasets",
"deepspeed>=0.6.4",
"deepspeed>=0.6.5",
"dill<0.3.5",
"fairscale>0.3",
"faiss-cpu",
27 changes: 11 additions & 16 deletions src/transformers/deepspeed.py
@@ -364,18 +364,6 @@ def _lr_scheduler_callable(optimizer):
return optimizer, lr_scheduler


def deepspeed_reinit(trainer):
"""
this is a temp hack based on: https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until
Deepspeed fixes a bug where it can't resume from a checkpoint after it did some stepping
https://github.com/microsoft/DeepSpeed/issues/1612
"""
import deepspeed

deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**trainer.deepspeed_initialize_kwargs)
return deepspeed_engine, optimizer, lr_scheduler


def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
"""
Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
Expand All @@ -390,15 +378,24 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf

Returns: model, optimizer, lr_scheduler

We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
https://github.com/microsoft/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
can't resume from a checkpoint after it did some stepping https://github.com/microsoft/DeepSpeed/issues/1612

"""
import deepspeed
from deepspeed.utils import logger as ds_logger

model = trainer.model
args = trainer.args

if hasattr(trainer, "hf_deepspeed_config_orig"):
hf_deepspeed_config = deepcopy(trainer.hf_deepspeed_config_orig)
else:
hf_deepspeed_config = args.hf_deepspeed_config
trainer.hf_deepspeed_config_orig = deepcopy(args.hf_deepspeed_config)

# resume config update - some bits like `model` and `num_training_steps` only become available during train
hf_deepspeed_config = args.hf_deepspeed_config
hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
config = hf_deepspeed_config.config

@@ -416,6 +413,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):
optimizer, lr_scheduler = None, None
model_parameters = None
else:
trainer.optimizer = None # important for when deepspeed_init is used as re-init
optimizer, lr_scheduler = deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps)
model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))

@@ -432,9 +430,6 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inference=False):

deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)

# stash kwargs to enabled a later deepspeed_reinit
trainer.deepspeed_initialize_kwargs = kwargs

if resume_from_checkpoint is not None:

# it's possible that the user is trying to resume from model_path, which doesn't necessarily
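The deepspeed.py change above makes `deepspeed_init` re-entrant: `trainer_config_finalize` mutates the DeepSpeed config with values that only exist at train time, so the first call stashes a pristine deep copy (`hf_deepspeed_config_orig`) and every later call starts from that copy instead of the already-finalized one. A runnable toy sketch of that pattern (stand-in objects only, not the real `HfDeepSpeedConfig` or DeepSpeed API):

from copy import deepcopy


class ToyTrainer:
    # Stand-in for Trainer: `config` plays the role of args.hf_deepspeed_config.
    def __init__(self, config):
        self.config = config


def toy_deepspeed_init(trainer, num_training_steps):
    if hasattr(trainer, "config_orig"):
        # re-init: start from the pristine copy stashed on the first call
        config = deepcopy(trainer.config_orig)
    else:
        config = trainer.config
        trainer.config_orig = deepcopy(trainer.config)

    # stand-in for trainer_config_finalize(): fills in run-specific values,
    # mutating the config in place
    config["scheduler"] = {"params": {"total_num_steps": num_training_steps}}
    return config


trainer = ToyTrainer({"zero_optimization": {"stage": 3}})
toy_deepspeed_init(trainer, num_training_steps=10)
cfg = toy_deepspeed_init(trainer, num_training_steps=500)  # re-init, e.g. when loading the best model
assert cfg["scheduler"]["params"]["total_num_steps"] == 500
assert "scheduler" not in trainer.config_orig  # the stashed copy was never finalized

This is also why the trainer.py hunk below can simply call `deepspeed_init` again rather than the removed `deepspeed_reinit` helper.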
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -9,7 +9,7 @@
"cookiecutter": "cookiecutter==1.7.3",
"dataclasses": "dataclasses",
"datasets": "datasets",
"deepspeed": "deepspeed>=0.6.4",
"deepspeed": "deepspeed>=0.6.5",
"dill": "dill<0.3.5",
"fairscale": "fairscale>0.3",
"faiss-cpu": "faiss-cpu",
17 changes: 12 additions & 5 deletions src/transformers/trainer.py
@@ -65,7 +65,7 @@
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
from .deepspeed import deepspeed_init, deepspeed_reinit, is_deepspeed_zero3_enabled
from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
from .dependency_versions_check import dep_version_check
from .modelcard import TrainingSummary
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
@@ -1749,16 +1749,23 @@ def _load_best_model(self):
model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if os.path.exists(best_model_path):
if self.deepspeed:

if self.model_wrapped is not None:
# this removes the pre-hooks from the previous engine
self.model_wrapped.destroy()
self.model_wrapped = None

# temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self,
num_training_steps=self.args.max_steps,
resume_from_checkpoint=self.state.best_model_checkpoint,
)
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.deepspeed.load_checkpoint(
self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
)
else:
# We load the model state dict on the CPU to avoid an OOM error.
state_dict = torch.load(best_model_path, map_location="cpu")
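The `self.model_wrapped.destroy()` call above exists because a DeepSpeed engine registers forward pre-hooks on the module it wraps (ZeRO-3, for instance, relies on hooks to gather partitioned parameters), and those hooks would otherwise stay attached when the same module is handed to the freshly initialized engine. A toy sketch of that failure mode, using plain PyTorch hooks rather than the real DeepSpeed engine:

import torch
from torch import nn


class ToyEngine:
    # Stand-in for a wrapper that hooks the module it wraps.
    def __init__(self, module, tag):
        self.module = module
        self._handle = module.register_forward_pre_hook(
            lambda mod, inputs: print(f"pre-hook from {tag} engine")
        )

    def destroy(self):
        # What model_wrapped.destroy() accomplishes in the hunk above:
        # detach this engine's hooks from the shared module.
        self._handle.remove()


model = nn.Linear(2, 2)
old = ToyEngine(model, tag="old")
old.destroy()                 # skip this line and both hooks fire on the call below
new = ToyEngine(model, tag="new")
model(torch.zeros(1, 2))      # prints only "pre-hook from new engine"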
184 changes: 115 additions & 69 deletions tests/deepspeed/test_deepspeed.py
@@ -20,6 +20,8 @@
import unittest
from copy import deepcopy

import datasets

from parameterized import parameterized
from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa
from transformers import AutoModel, TrainingArguments, is_torch_available, logging
@@ -195,28 +197,7 @@ def test_init_zero3_fp16(self):
self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)


@require_deepspeed
@require_torch_gpu
class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
"""

This class is for testing directly via get_regression_trainer

It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
which we can re-use here.

Important: this class' setup can only work with a single gpu because it runs within the current
pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.

Note: if any of the tests of this class get run there will be at least one gpu occupied by them
until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
won't be released until this pytest worker exits.

This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
is not a bug.
"""

class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
def setUp(self):
super().setUp()

@@ -252,6 +233,29 @@ def get_config_dict(self, stage):
# As some tests modify the dict, always make a copy
return deepcopy(self.ds_config_dict[stage])


@require_deepspeed
@require_torch_gpu
class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, TrainerIntegrationCommon):
"""

This class is for testing directly via get_regression_trainer

It mixes in `TrainerIntegrationCommon` which already has a lot of helper validation methods
which we can re-use here.

Important: this class' setup can only work with a single gpu because it runs within the current
pytest worker. For multi-gpu tests use TestDeepSpeedWithLauncher.

Note: if any of the tests of this class get run there will be at least one gpu occupied by them
until this pytest worker exits. This is because the gpu memory allocated by the cuda-kernels
won't be released until this pytest worker exits.

This may appear as some run-away tests if you watch `nvidia-smi` while other tests that fork new
processes are run. So there will be one or two "stale" processes reported in `nvidia-smi`. This
is not a bug.
"""

# --- These tests are enough to run on one of zero stages --- #

def test_hf_ds_config_mismatch(self):
@@ -725,6 +729,95 @@ def test_config_object(self):
self.assertFalse(is_deepspeed_zero3_enabled())
self.assertFalse(bool(config), "Deepspeed config should not be accessible")

@parameterized.expand(params, name_func=parameterized_custom_name_func)
def test_load_best_model(self, stage, dtype):
# Test that forced deepspeed reinit doesn't break the model. the forced re-init after
# loading the best model in Trainer is there to workaround this bug in Deepspeed
# https://github.com/microsoft/DeepSpeed/issues/1612
#
# The test is derived from a repro script submitted in this Issue:
# https://github.com/huggingface/transformers/issues/17114
#
# One additional feature of this test is that we use a non-AdamW optimizer to test that
# deepspeed doesn't fallback to AdamW, which would prevent the optimizer states from loading
# correctly

from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer # noqa

output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False, before=False)

ds_config_dict = self.get_config_dict(stage)
del ds_config_dict["optimizer"] # will use HF Trainer optimizer
del ds_config_dict["scheduler"] # will use HF Trainer scheduler
# must use this setting to get the reload path exercised
ds_config_dict["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = True

tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
model = T5ForConditionalGeneration.from_pretrained(T5_TINY)

def _add_eos_to_examples(example):
example["input_text"] = f"question: {example['question']} context: {example['context']}"
example["target_text"] = example["answers"]["text"][0] if len(example["answers"]["text"]) > 0 else ""
return example

def _convert_to_features(example_batch):
input_encodings = tokenizer.batch_encode_plus(
example_batch["input_text"], pad_to_max_length=True, max_length=512, truncation=True
)
target_encodings = tokenizer.batch_encode_plus(
example_batch["target_text"], pad_to_max_length=True, max_length=16, truncation=True
)

encodings = {
"input_ids": input_encodings["input_ids"],
"attention_mask": input_encodings["attention_mask"],
"labels": target_encodings["input_ids"],
}

return encodings

def get_dataset():
data_file = str(self.tests_dir / "fixtures/tests_samples/SQUAD/sample.json")
data_files = dict(train=data_file, validation=data_file)
raw_datasets = datasets.load_dataset("json", data_files=data_files, field="data")
train_dataset = raw_datasets["train"].map(_add_eos_to_examples).map(_convert_to_features, batched=True)
valid_dataset = deepcopy(train_dataset)
return train_dataset, valid_dataset

train_dataset, eval_dataset = get_dataset()

args_dict = {
"per_gpu_train_batch_size": 1,
"per_gpu_eval_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-4,
"num_train_epochs": 1,
"do_train": True,
"do_eval": True,
"optim": "adafactor",
"evaluation_strategy": "steps",
"eval_steps": 1,
"save_strategy": "steps",
"save_steps": 1,
"load_best_model_at_end": True,
"max_steps": 1,
"deepspeed": ds_config_dict,
}

with mockenv_context(**self.dist_env_1_gpu):

training_args = TrainingArguments(output_dir, **args_dict)

trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train() # crash 1 was here
trainer.evaluate() # crash 2 was here


@slow
@require_deepspeed
@@ -1035,50 +1128,3 @@ def test_clm_from_config_zero3_fp16(self):
with CaptureStderr() as cs:
execute_subprocess_async(cmd, env=self.get_env())
self.assertIn("Detected DeepSpeed ZeRO-3", cs.err)

@parameterized.expand(params, name_func=parameterized_custom_name_func)
def test_load_best_model(self, stage, dtype):
# this test exercises --load_best_model_at_end - the key is being able to resume after some training

data_dir = self.tests_dir / "fixtures/tests_samples/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
args = f"""
--model_name_or_path {T5_TINY}
--tokenizer_name {T5_TINY}
--train_file {data_dir}/train.json
--validation_file {data_dir}/val.json
--output_dir {output_dir}
--overwrite_output_dir
--source_lang en
--target_lang ro
--do_train
--max_train_samples 3
--do_eval
--max_eval_samples 1
--logging_strategy steps
--logging_steps 1
--evaluation_strategy steps
--eval_steps 1
--save_strategy steps
--save_steps 1
--load_best_model_at_end
--per_device_train_batch_size 1
--per_device_eval_batch_size 1
--num_train_epochs 1
--report_to none
""".split()
args.extend(["--source_prefix", "translate English to Romanian: "])

args.extend([f"--{dtype}"])

ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config_{stage}.json".split()
script = [f"{self.examples_dir_str}/pytorch/translation/run_translation.py"]
launcher = get_launcher(distributed=False)

cmd = launcher + script + args + ds_args
# keep for quick debug
# print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
with CaptureStd() as cs:
execute_subprocess_async(cmd, env=self.get_env())
# enough to test it didn't fail
self.assertIn("DeepSpeed info", cs.out)