7 changes: 7 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1465,9 +1465,13 @@ def _from_config(cls, config, **kwargs):
             # and memory copying it on CPU or each GPU first
             with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
                 model = cls(config, **kwargs)
+
         else:
             model = cls(config, **kwargs)

+        # Flag whether we initialized with `zero3`; add an attr to the model so we can check downstream for issues
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled()
+
         # restore default dtype if it was modified
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
@@ -3797,6 +3801,9 @@ def from_pretrained(
             # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)

+        # If we initialized with `zero3`, add an attr to the model so we can check downstream for issues
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled() and not is_quantized
+
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config

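Both hunks stamp the returned model with a boolean recording whether ZeRO-3's `zero.Init()` context was active at construction time. A minimal sketch of the flag's value when no DeepSpeed config has been registered yet (the checkpoint name is illustrative):

```python
from transformers import AutoModel

# No ZeRO-3 config has been registered, so is_deepspeed_zero3_enabled() is
# False and the new attribute records that zero.Init() was not used.
model = AutoModel.from_pretrained("t5-small")
print(model._transformers_zero3_init_used)  # False
```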
10 changes: 10 additions & 0 deletions src/transformers/trainer.py
@@ -100,6 +100,7 @@
     get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
+    is_deepspeed_zero3_enabled,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -435,6 +436,15 @@ def __init__(
             )
         self.model_init = model_init

+        # Will reach this branch if the user has:
+        # 1. Used `.from_pretrained` or `.from_config` to initialize their model
+        # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
+        # Models initialized from scratch, e.g. `MyModel()`, will not hit this check
+        if is_deepspeed_zero3_enabled() and not getattr(model, "_transformers_zero3_init_used", True):
+            raise ValueError(
+                "Model was not initialized with `Zero-3` despite being configured for DeepSpeed Zero-3. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
+            )
+
         if model.__class__.__name__ in MODEL_MAPPING_NAMES:
             raise ValueError(
                 f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
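The check is deliberately ordering-sensitive: a model built before ZeRO-3 was configured carries `_transformers_zero3_init_used = False`, while a freshly written `nn.Module` lacks the attribute and falls back to the `getattr` default of `True`, so it passes. A hedged sketch of the failure and the fix (checkpoint name and DeepSpeed config dict are illustrative; assumes `deepspeed` is installed):

```python
from transformers import AutoModel, Trainer, TrainingArguments

ds_zero3 = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {"stage": 3},
}

# Wrong order: the model is created before ZeRO-3 is configured, so
# zero.Init() never ran and Trainer raises the new ValueError.
model = AutoModel.from_pretrained("t5-small")
args = TrainingArguments(output_dir="out", deepspeed=ds_zero3)
# Trainer(model=model, args=args)  # -> ValueError

# Right order: configure ZeRO-3 first, then load the model under zero.Init().
args = TrainingArguments(output_dir="out", deepspeed=ds_zero3)
model = AutoModel.from_pretrained("t5-small")
trainer = Trainer(model=model, args=args)
```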
25 changes: 25 additions & 0 deletions tests/deepspeed/test_deepspeed.py
@@ -709,6 +709,31 @@ def test_gradient_accumulation(self, stage, dtype):
         # Relative difference. See the note above how to get identical loss on a small bs
         self.assertTrue((no_grad_accum_loss - yes_grad_accum_loss) / (no_grad_accum_loss + 1e-15) <= 1e-3)

+    def test_missed_zero3_init(self):
+        from transformers import Trainer  # noqa
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            model = AutoModel.from_pretrained(T5_TINY)
+            training_args = TrainingArguments(
+                output_dir="./test_missed_zero3_init",
+                deepspeed=self.get_config_dict(ZERO3),
+            )
+            with self.assertRaises(
+                ValueError, msg="Model was not initialized with `Zero-3` despite being configured."
+            ):
+                _ = Trainer(
+                    model=model,
+                    args=training_args,
+                )
+            # Now do it properly: the `TrainingArguments` created above already registered Zero-3
+            model = AutoModel.from_pretrained(T5_TINY)
+            trainer = Trainer(
+                model=model,
+                args=training_args,
+            )
+            assert trainer.is_deepspeed_enabled
+            assert model._transformers_zero3_init_used
+
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
         file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
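The test leans on a side effect worth spelling out: constructing `TrainingArguments` with a `deepspeed` config registers that config globally, which is what flips `is_deepspeed_zero3_enabled()` to `True` for every later `from_pretrained` call. A rough sketch of that behavior (the import path may differ across transformers versions; the config dict is illustrative):

```python
from transformers import TrainingArguments
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

assert not is_deepspeed_zero3_enabled()  # nothing registered yet

ds_zero3 = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {"stage": 3},
}
args = TrainingArguments(output_dir="out", deepspeed=ds_zero3)

# The config is held via a module-level weakref to `args`, so as long as
# `args` stays alive, models loaded from here on use zero.Init().
assert is_deepspeed_zero3_enabled()
```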