diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index bdc41263013..662c2f034f9 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -697,3 +697,27 @@ trainer.train() ``` For more details, see the [MiniLLM Trainer documentation](minillm) documentation. + +## Distributed Training + +### ZeRO: Memory Optimizations Toward Training Trillion Parameter Models + +**📜 Paper**: https://huggingface.co/papers/1910.02054 + +ZeRO (Zero Redundancy Optimizer) eliminates memory redundancies in data- and model-parallel training by partitioning optimizer states, gradients, and parameters across devices while retaining low communication volume and high computational granularity. This allows for the efficient training of large models that would otherwise not fit in GPU memory. + +TRL supports ZeRO via the [DeepSpeed integration](deepspeed_integration). To use it, provide a DeepSpeed configuration file with your desired settings, + +```yaml +# config.yaml +distributed_type: DEEPSPEED +num_processes: 2 +deepspeed_config: + zero_stage: 3 +``` + +and launch the training script using `accelerate launch --config_file config_file`. + +```sh +accelerate launch --config_file config.yaml train.py +``` diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 82cd59c2c04..2c85cac14d4 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -296,7 +296,11 @@ def __init__( # Model and reference model if isinstance(model, str): - model = create_model_from_path(model, **args.model_init_kwargs or {}) + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) else: if args.model_init_kwargs is not None: logger.warning( @@ -305,7 +309,11 @@ def __init__( ) model_id = get_config_model_id(model.config) if isinstance(ref_model, str): - ref_model = create_model_from_path(ref_model, **args.ref_model_init_kwargs or {}) + model_init_kwargs = args.ref_model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + ref_model = create_model_from_path(ref_model, **model_init_kwargs) else: if args.ref_model_init_kwargs is not None: logger.warning( diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index d1bdcebb810..d6c4e4501ea 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -603,7 +603,11 @@ def __init__( # Model if isinstance(model, str): - model = create_model_from_path(model, **args.model_init_kwargs or {}) + model_init_kwargs = args.model_init_kwargs or {} + # Special case for DeepSpeed: requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type == "DEEPSPEED": + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) else: if args.model_init_kwargs is not None: logger.warning(