diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 24e8546a50f1..47d7a7ffcb5f 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -130,11 +130,58 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
 
     fill_only = partialmethod(fill_match, must_match=False)
 
+    def override_training_args_from_deepspeed(self, args):
+        """
+        Override TrainingArguments based on DeepSpeed config values to ensure compatibility.
+
+        This method ensures that the DeepSpeed config takes precedence over TrainingArguments
+        defaults when there are conflicts, particularly for mixed precision settings.
+
+        Args:
+            args: TrainingArguments object to potentially modify
+        """
+        # Check precision settings in the DeepSpeed config and override TrainingArguments accordingly.
+        # Only override defaults, not explicit user settings.
+
+        # Check if the user explicitly set precision options (the defaults are False)
+        user_set_fp16 = args.fp16 is True
+        user_set_bf16 = args.bf16 is True
+
+        if self.is_true("fp16.enabled"):
+            # DeepSpeed config explicitly enables fp16
+            if not user_set_fp16 and not user_set_bf16:
+                # User didn't explicitly set either, so apply the DeepSpeed config
+                args.fp16 = True
+                args.bf16 = False
+            elif user_set_bf16 and not user_set_fp16:
+                # User explicitly chose bf16, but the DeepSpeed config wants fp16.
+                # This is a potential conflict - the user's explicit choice wins.
+                pass  # Keep user's bf16=True, fp16=False
+        elif self.is_true("bf16.enabled"):
+            # DeepSpeed config explicitly enables bf16
+            if not user_set_fp16 and not user_set_bf16:
+                # User didn't explicitly set either, so apply the DeepSpeed config
+                args.bf16 = True
+                args.fp16 = False
+            elif user_set_fp16 and not user_set_bf16:
+                # User explicitly chose fp16, but the DeepSpeed config wants bf16.
+                # This is a potential conflict - the user's explicit choice wins.
+                pass  # Keep user's fp16=True, bf16=False
+        elif self.is_false("fp16.enabled") and self.is_false("bf16.enabled"):
+            # Both are explicitly disabled in the DeepSpeed config
+            if not user_set_fp16 and not user_set_bf16:
+                # User didn't explicitly set either, so apply the DeepSpeed config (fp32)
+                args.fp16 = False
+                args.bf16 = False
+
     def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
         creation.
""" + # First, override TrainingArguments based on DeepSpeed config to ensure compatibility + self.override_training_args_from_deepspeed(args) + # DeepSpeed does: # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index da0721eee0c9..32938a83efc2 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1853,14 +1853,8 @@ def __post_init__(self): torch.backends.cudnn.allow_tf32 = False # no need to assert on else - # if training args is specified, it will override the one specified in the accelerate config - if self.half_precision_backend != "apex": - mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") - if self.fp16: - mixed_precision_dtype = "fp16" - elif self.bf16: - mixed_precision_dtype = "bf16" - os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + # NOTE: Mixed precision environment variable setting moved to after DeepSpeed processing + # to ensure DeepSpeed config can override TrainingArguments defaults if self.report_to is None: logger.info( @@ -2070,6 +2064,16 @@ def __post_init__(self): self.deepspeed_plugin.set_mixed_precision(mixed_precision) self.deepspeed_plugin.set_deepspeed_weakref() + # Set mixed precision environment variable after DeepSpeed processing + # This ensures DeepSpeed config overrides have been applied to fp16/bf16 settings + if self.half_precision_backend != "apex": + mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no") + if self.fp16: + mixed_precision_dtype = "fp16" + elif self.bf16: + mixed_precision_dtype = "bf16" + os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype + if self.use_cpu: self.dataloader_pin_memory = False diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 99b1450a0d59..e3dc9fc08c99 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -1431,3 +1431,50 @@ def test_clm_from_config_zero3_fp16(self): with CaptureStderr() as cs: execute_subprocess_async(cmd, env=self.get_env()) self.assertIn("Detected DeepSpeed ZeRO-3", cs.err) + + +@require_deepspeed +class TestDeepSpeedMixedPrecisionPrecedence(TestCasePlus): + """Test DeepSpeed mixed precision precedence over Accelerate defaults.""" + + def setUp(self): + super().setUp() + unset_hf_deepspeed_config() + + def tearDown(self): + super().tearDown() + unset_hf_deepspeed_config() + + def test_deepspeed_fp16_overrides_defaults(self): + """Test that DeepSpeed fp16 config overrides TrainingArguments defaults""" + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False) + ds_config = {"fp16": {"enabled": True}, "bf16": {"enabled": False}, "zero_optimization": {"stage": 2}} + hf_ds_config = HfTrainerDeepSpeedConfig(ds_config) + hf_ds_config.trainer_config_process(args) + self.assertTrue(args.fp16) + self.assertFalse(args.bf16) + + def test_deepspeed_bf16_overrides_defaults(self): + """Test that DeepSpeed bf16 config overrides TrainingArguments defaults""" + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + args = TrainingArguments(output_dir="./test_output", fp16=False, bf16=False) + ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": 
{"stage": 2}} + hf_ds_config = HfTrainerDeepSpeedConfig(ds_config) + hf_ds_config.trainer_config_process(args) + self.assertTrue(args.bf16) + self.assertFalse(args.fp16) + + def test_user_explicit_settings_preserved(self): + """Test that explicit user settings are preserved over DeepSpeed config""" + from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig + + args = TrainingArguments(output_dir="./test_output", fp16=True, bf16=False) # User explicit + ds_config = {"fp16": {"enabled": False}, "bf16": {"enabled": True}, "zero_optimization": {"stage": 2}} + hf_ds_config = HfTrainerDeepSpeedConfig(ds_config) + hf_ds_config.trainer_config_process(args) + # User's explicit choice should be preserved + self.assertTrue(args.fp16) + self.assertFalse(args.bf16)