From 7a802269dd26c521f2f0a83c997e29f1f44d3029 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Mon, 29 Jul 2024 11:15:12 -0400
Subject: [PATCH 1/7] Test this zach

---
 src/transformers/modeling_utils.py |  4 ++++
 src/transformers/trainer.py        |  4 ++++
 tests/deepspeed/test_deepspeed.py  | 21 +++++++++++++++++++++
 3 files changed, 29 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 557624f78b66..ada42beb2506 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3797,6 +3797,10 @@ def from_pretrained(
             # Let's make sure we don't run the init function of buffer modules
             model = cls(config, *model_args, **model_kwargs)
 
+        # If we init with `zero3`, add an attr to the model so we can check downstream for issues
+        if is_deepspeed_zero3_enabled():
+            model.transformers_zero3_init_used = True
+
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config
 
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 25b7b6993092..085cc71cfbba 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4678,6 +4678,10 @@ def create_accelerator_and_postprocess(self):
         if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
             self.propagate_args_to_deepspeed()
 
+        if self.is_deepspeed_enabled and is_deepspeed_zero3_enabled():
+            if not getattr(self.model, "transformers_zero3_init_used", False):
+                raise ValueError("Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via AutoModel.from_pretrained(...) after creating your `TrainingArguments`!")
+
         # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
         if (
             self.args.save_only_model
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 7b50165babf4..9c1d0af0fccb 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -229,6 +229,27 @@ def test_init_zero3_fp16(self):
             AutoModel.from_pretrained(T5_TINY)
         self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
 
+    def test_zero3_misconfigured(self):
+        # test that catches if a model was created before `zero.Init()` was called
+        AutoModel.from_pretrained(T5_TINY)
+
+
+        # Now add in zero optimization
+        ds_config = {
+            "train_batch_size": 1,
+            "zero_optimization": {
+                "stage": 3,
+            },
+        }
+
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertTrue(dschf.is_zero3())
+        self.assertTrue(is_deepspeed_zero3_enabled())
+
+        with self.assertRaises(ValueError, msg="Model was not initialized with `Zero-3` despite being configured."):
+            AutoModel.from_pretrained(T5_TINY)
+
     def test_init_zero3_missing_params(self):
         # test that zero.Init() for missing parameters works correctly under zero3
         import deepspeed

From f296a1baf935b04b7dec90b28ffb457dbbe9136a Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Mon, 29 Jul 2024 12:02:19 -0400
Subject: [PATCH 2/7] Test for improper init w/o zero3

---
 src/transformers/modeling_utils.py |  8 +++++-
 src/transformers/trainer.py        | 15 +++++++---
 tests/deepspeed/test_deepspeed.py  | 46 ++++++++++++++++--------------
 tests/trainer/test_trainer.py      | 12 ++++----
 4 files changed, 49 insertions(+), 32 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ada42beb2506..5b066a753678 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1465,9 +1465,13 @@ def _from_config(cls, config, **kwargs):
             # and memory copying it on CPU or each GPU first
             with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
                 model = cls(config, **kwargs)
+
         else:
             model = cls(config, **kwargs)
 
+        # Flag for if we init with `zero3`, add an attr to the model so we can check downstream for issues
+        model.transformers_zero3_init_used = is_deepspeed_zero3_enabled()
+
         # restore default dtype if it was modified
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
@@ -3798,8 +3802,10 @@ def from_pretrained(
         model = cls(config, *model_args, **model_kwargs)
 
         # If we init with `zero3`, add an attr to the model so we can check downstream for issues
-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not is_quantized:
             model.transformers_zero3_init_used = True
+        else:
+            model.transformers_zero3_init_used = False
 
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 085cc71cfbba..254fc5b5f7ee 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -100,6 +100,7 @@
     get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
+    is_deepspeed_zero3_enabled,
     nested_concat,
     nested_detach,
     nested_numpify,
@@ -435,6 +436,16 @@ def __init__(
             )
         self.model_init = model_init
 
+        if is_deepspeed_zero3_enabled():
+            # Will reach this branch if the user has
+            # 1. Used `.from_pretrained` or `.from_config` to initialize their model
+            # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
+            # New models init such as `MyModel()` will not hit this step
+            if hasattr(model, "transformers_zero3_init_used") and not model.transformers_zero3_init_used:
+                raise ValueError(
+                    "Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via `***Model.from_pretrained(...)` or `***Model.from_config(...)` after creating your `TrainingArguments`!"
+                )
+
         if model.__class__.__name__ in MODEL_MAPPING_NAMES:
             raise ValueError(
                 f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
@@ -4678,10 +4689,6 @@ def create_accelerator_and_postprocess(self):
         if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
             self.propagate_args_to_deepspeed()
 
-        if self.is_deepspeed_enabled and is_deepspeed_zero3_enabled():
-            if not getattr(self.model, "transformers_zero3_init_used", False):
-                raise ValueError("Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via AutoModel.from_pretrained(...) after creating your `TrainingArguments`!")
-
         # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
         if (
             self.args.save_only_model
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 9c1d0af0fccb..429db7d8f03c 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -229,27 +229,6 @@ def test_init_zero3_fp16(self):
             AutoModel.from_pretrained(T5_TINY)
         self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
 
-    def test_zero3_misconfigured(self):
-        # test that catches if a model was created before `zero.Init()` was called
-        AutoModel.from_pretrained(T5_TINY)
-
-
-        # Now add in zero optimization
-        ds_config = {
-            "train_batch_size": 1,
-            "zero_optimization": {
-                "stage": 3,
-            },
-        }
-
-        dschf = HfDeepSpeedConfig(ds_config)
-
-        self.assertTrue(dschf.is_zero3())
-        self.assertTrue(is_deepspeed_zero3_enabled())
-
-        with self.assertRaises(ValueError, msg="Model was not initialized with `Zero-3` despite being configured."):
-            AutoModel.from_pretrained(T5_TINY)
-
     def test_init_zero3_missing_params(self):
         # test that zero.Init() for missing parameters works correctly under zero3
         import deepspeed
@@ -730,6 +709,31 @@ def test_gradient_accumulation(self, stage, dtype):
         # Relative difference. See the note above how to get identical loss on a small bs
         self.assertTrue((no_grad_accum_loss - yes_grad_accum_loss) / (no_grad_accum_loss + 1e-15) <= 1e-3)
 
+    def test_missed_zero3_init(self):
+        from transformers import Trainer  # noqa
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            model = AutoModel.from_pretrained(T5_TINY)
+            training_args = TrainingArguments(
+                output_dir="./test_missed_zero3_init",
+                deepspeed=self.get_config_dict(ZERO3),
+            )
+            with self.assertRaises(
+                ValueError, msg="Model was not initialized with `Zero-3` despite being configured."
+            ):
+                _ = Trainer(
+                    model=model,
+                    args=training_args,
+                )
+            # Now do it proper, triggered from our `TrainingArguments` earlier
+            model = AutoModel.from_pretrained(T5_TINY)
+            trainer = Trainer(
+                model=model,
+                args=training_args,
+            )
+            assert trainer.is_deepspeed_enabled
+            assert model.transformers_zero3_init_used
+
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints
         file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 7378a597c39c..88a069df007d 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -434,6 +434,12 @@ def get_regression_trainer(
     train_dataset = RegressionDataset(length=train_len, label_names=label_names)
     eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
 
+    compute_metrics = kwargs.pop("compute_metrics", None)
+    data_collator = kwargs.pop("data_collator", None)
+    optimizers = kwargs.pop("optimizers", (None, None))
+    output_dir = kwargs.pop("output_dir", "./regression")
+    preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
+
     model_init = kwargs.pop("model_init", None)
     if model_init is not None:
         model = None
@@ -450,12 +456,6 @@ def get_regression_trainer(
     else:
         model = RegressionModel(a=a, b=b, double_output=double_output)
 
-    compute_metrics = kwargs.pop("compute_metrics", None)
-    data_collator = kwargs.pop("data_collator", None)
-    optimizers = kwargs.pop("optimizers", (None, None))
-    output_dir = kwargs.pop("output_dir", "./regression")
-    preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
-
     args = RegressionTrainingArguments(output_dir, a=a, b=b, keep_report_to=keep_report_to, **kwargs)
     return Trainer(
         model,

From 4f9c0f098e5d04fa0557f50e858037c8e1f872e4 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Mon, 29 Jul 2024 12:03:11 -0400
Subject: [PATCH 3/7] Move back

---
 tests/trainer/test_trainer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 88a069df007d..7378a597c39c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -434,12 +434,6 @@ def get_regression_trainer(
     train_dataset = RegressionDataset(length=train_len, label_names=label_names)
     eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
 
-    compute_metrics = kwargs.pop("compute_metrics", None)
-    data_collator = kwargs.pop("data_collator", None)
-    optimizers = kwargs.pop("optimizers", (None, None))
-    output_dir = kwargs.pop("output_dir", "./regression")
-    preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
-
     model_init = kwargs.pop("model_init", None)
     if model_init is not None:
         model = None
@@ -456,6 +450,12 @@ def get_regression_trainer(
     else:
         model = RegressionModel(a=a, b=b, double_output=double_output)
 
+    compute_metrics = kwargs.pop("compute_metrics", None)
+    data_collator = kwargs.pop("data_collator", None)
+    optimizers = kwargs.pop("optimizers", (None, None))
+    output_dir = kwargs.pop("output_dir", "./regression")
+    preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
+
     args = RegressionTrainingArguments(output_dir, a=a, b=b, keep_report_to=keep_report_to, **kwargs)
     return Trainer(
         model,

From f348b3808ce8cd430242dde6d84bd99a04bd6b84 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Thu, 1 Aug 2024 09:45:18 -0400
Subject: [PATCH 4/7] Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/modeling_utils.py |  5 +----
 src/transformers/trainer.py        | 17 ++++++++---------
 tests/deepspeed/test_deepspeed.py  |  2 +-
 3 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 5b066a753678..a17209f168c4 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3802,10 +3802,7 @@ def from_pretrained(
         model = cls(config, *model_args, **model_kwargs)
 
         # If we init with `zero3`, add an attr to the model so we can check downstream for issues
-        if is_deepspeed_zero3_enabled() and not is_quantized:
-            model.transformers_zero3_init_used = True
-        else:
-            model.transformers_zero3_init_used = False
+        model.transformers_zero3_init_used = is_deepspeed_zero3_enabled() and not is_quantized
 
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 254fc5b5f7ee..742a5a9c8bcd 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -436,15 +436,14 @@ def __init__(
             )
         self.model_init = model_init
 
-        if is_deepspeed_zero3_enabled():
-            # Will reach this branch if the user has
-            # 1. Used `.from_pretrained` or `.from_config` to initialize their model
-            # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
-            # New models init such as `MyModel()` will not hit this step
-            if hasattr(model, "transformers_zero3_init_used") and not model.transformers_zero3_init_used:
-                raise ValueError(
-                    "Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via `***Model.from_pretrained(...)` or `***Model.from_config(...)` after creating your `TrainingArguments`!"
-                )
+        # Will reach this branch if the user has
+        # 1. Used `.from_pretrained` or `.from_config` to initialize their model
+        # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
+        # New models init such as `MyModel()` will not hit this step
+        if is_deepspeed_zero3_enabled() and not getattr(model, "transformers_zero3_init_used", True):
+            raise ValueError(
+                "Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via `***Model.from_pretrained(...)` or `***Model.from_config(...)` after creating your `TrainingArguments`!"
+            )
 
         if model.__class__.__name__ in MODEL_MAPPING_NAMES:
             raise ValueError(
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 429db7d8f03c..4e2ebeac91f8 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -725,7 +725,7 @@ def test_missed_zero3_init(self):
                 model=model,
                 args=training_args,
             )
-            # Now do it proper, triggered from our `TrainingArguments` earlier
+            # Now do it properly, triggered from our `TrainingArguments` earlier
             model = AutoModel.from_pretrained(T5_TINY)
             trainer = Trainer(
                 model=model,

From 55c5f4dfcefef004fc505af333f6d01fa9b73a79 Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 1 Aug 2024 09:50:59 -0400
Subject: [PATCH 5/7] Get rid of stars in warning

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 742a5a9c8bcd..1595ec1f2ebc 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -442,7 +442,7 @@ def __init__(
         # New models init such as `MyModel()` will not hit this step
         if is_deepspeed_zero3_enabled() and not getattr(model, "transformers_zero3_init_used", True):
             raise ValueError(
-                "Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via `***Model.from_pretrained(...)` or `***Model.from_config(...)` after creating your `TrainingArguments`!"
+                "Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
             )
 
         if model.__class__.__name__ in MODEL_MAPPING_NAMES:

From 0e992206c99adb131c24b7a0287012d2ca2e26ec Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 1 Aug 2024 09:52:38 -0400
Subject: [PATCH 6/7] Make private

---
 src/transformers/modeling_utils.py | 4 ++--
 src/transformers/trainer.py        | 2 +-
 tests/deepspeed/test_deepspeed.py  | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index a17209f168c4..ecfb4d37d59a 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1470,7 +1470,7 @@ def _from_config(cls, config, **kwargs):
             model = cls(config, **kwargs)
 
         # Flag for if we init with `zero3`, add an attr to the model so we can check downstream for issues
-        model.transformers_zero3_init_used = is_deepspeed_zero3_enabled()
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled()
 
         # restore default dtype if it was modified
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)
@@ -3802,7 +3802,7 @@ def from_pretrained(
         model = cls(config, *model_args, **model_kwargs)
 
         # If we init with `zero3`, add an attr to the model so we can check downstream for issues
-        model.transformers_zero3_init_used = is_deepspeed_zero3_enabled() and not is_quantized
+        model._transformers_zero3_init_used = is_deepspeed_zero3_enabled() and not is_quantized
 
         # make sure we use the model's config since the __init__ call might have copied it
         config = model.config
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1595ec1f2ebc..4a297bcadd3f 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -440,7 +440,7 @@ def __init__(
         # 1. Used `.from_pretrained` or `.from_config` to initialize their model
        # 2. Did not configure Zero-3 via `TrainingArguments` or `accelerate launch` beforehand
         # New models init such as `MyModel()` will not hit this step
-        if is_deepspeed_zero3_enabled() and not getattr(model, "transformers_zero3_init_used", True):
+        if is_deepspeed_zero3_enabled() and not getattr(model, "_transformers_zero3_init_used", True):
             raise ValueError(
                 "Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
             )
diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py
index 4e2ebeac91f8..7b81ba40e47b 100644
--- a/tests/deepspeed/test_deepspeed.py
+++ b/tests/deepspeed/test_deepspeed.py
@@ -732,7 +732,7 @@ def test_missed_zero3_init(self):
                 args=training_args,
             )
             assert trainer.is_deepspeed_enabled
-            assert model.transformers_zero3_init_used
+            assert model._transformers_zero3_init_used
 
     def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
         # adapted from TrainerIntegrationCommon.check_saved_checkpoints

From 1ded36fdc35b3d842928d4a33b281cc3ce43973a Mon Sep 17 00:00:00 2001
From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL"
Date: Thu, 1 Aug 2024 09:53:03 -0400
Subject: [PATCH 7/7] Make clear

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 4a297bcadd3f..e083b992d1b6 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -442,7 +442,7 @@ def __init__(
         # New models init such as `MyModel()` will not hit this step
         if is_deepspeed_zero3_enabled() and not getattr(model, "_transformers_zero3_init_used", True):
             raise ValueError(
-                "Model was not initialized with `Zero-3` despite being configured. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
+                "Model was not initialized with `Zero-3` despite being configured for DeepSpeed Zero-3. Please re-initialize your model via `Model.from_pretrained(...)` or `Model.from_config(...)` after creating your `TrainingArguments`!"
             )
 
         if model.__class__.__name__ in MODEL_MAPPING_NAMES:
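Usage sketch (hypothetical, not part of the patch series): the check added above is exercised end to end by the new `test_missed_zero3_init` test in patch 2. The snippet below is a minimal illustration of the same flow, assuming the final state of the branch (patch 7/7). The model id "t5-small" and the output directory are placeholder values, and both `Trainer` constructions, as well as the second `from_pretrained` call (which goes through `deepspeed.zero.Init()`), additionally assume DeepSpeed is installed and a real or mocked distributed environment is configured, as the test sets up via `mockenv_context`.

# Hypothetical sketch of the behavior introduced by this series; placeholder
# model id / paths, requires transformers + accelerate + deepspeed and a
# (possibly single-process, mocked) distributed environment.
from transformers import AutoModel, Trainer, TrainingArguments

ds_config = {
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "zero_optimization": {"stage": 3},
}

# Model created BEFORE ZeRO-3 is configured: zero.Init() was not used, so the
# `_transformers_zero3_init_used` flag is False and Trainer rejects the model.
model = AutoModel.from_pretrained("t5-small")
args = TrainingArguments(output_dir="./out", deepspeed=ds_config)  # enables ZeRO-3 globally
try:
    Trainer(model=model, args=args)
except ValueError as err:
    print(err)  # "Model was not initialized with `Zero-3` despite being configured for DeepSpeed Zero-3. ..."

# Re-initializing after the ZeRO-3 TrainingArguments exist picks up zero.Init(),
# sets the flag, and Trainer accepts the model.
model = AutoModel.from_pretrained("t5-small")
trainer = Trainer(model=model, args=args)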