diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 8189c22fe5a1..a8d3258afedf 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1627,6 +1627,12 @@ def __post_init__(self): os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config) + elif strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")): + # Accelerate DeepSpeed Plugin + from accelerate.utils import DeepSpeedPlugin + + self.deepspeed_plugin = DeepSpeedPlugin() + self.deepspeed_plugin.set_deepspeed_weakref() if self.push_to_hub_token is not None: warnings.warn(