diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9c99a6d770d1..b5df36e12a94 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3443,6 +3443,29 @@ def from_pretrained(
             # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
             raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
 
+        if tp_plan is not None and device_map is not None:
+            raise ValueError(
+                "`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
+            )
+
+        # We need to correctly dispatch the model on the current process device. The easiest way to do this is to use a simple
+        # `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
+        # child processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
+        # Temporarily setting the default device to the current process rank instead results in the following error:
+        # `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
+        tp_device = None
+        if tp_plan is not None:
+            if not torch.distributed.is_initialized():
+                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
+
+            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+            device_type = torch._C._get_accelerator().type
+            device_module = torch.get_device_module(device_type)
+            # Get device with index assuming equal number of devices per host
+            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
+            # This is the easiest way to dispatch to the current process device
+            device_map = tp_device
+
         if is_fsdp_enabled():
             low_cpu_mem_usage = True
 
@@ -4090,7 +4113,6 @@ def from_pretrained(
 
         # Instantiate model.
         init_contexts = [no_init_weights(_enable=_fast_init)]
-        tp_device = None
 
         if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
             import deepspeed
@@ -4106,16 +4128,6 @@ def from_pretrained(
                     f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
                 )
             init_contexts.append(init_empty_weights())
-        elif tp_plan is not None:
-            if not torch.distributed.is_initialized():
-                raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")
-
-            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
-            device_type = torch._C._get_accelerator().type
-            device_module = torch.get_device_module(device_type)
-            # Get device with index assuming equal number of devices per host
-            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
-            init_contexts.append(tp_device)
 
         if is_deepspeed_zero3_enabled() and is_quantized:
             init_contexts.append(set_quantized_state())
@@ -4249,38 +4261,32 @@ def from_pretrained(
         if dtype_orig is not None:
            torch.set_default_dtype(dtype_orig)
 
-        load_contexts = []
-        # Make sure we load onto targeted device
-        if tp_device is not None:
-            load_contexts.append(tp_device)
-
-        with ContextManagers(load_contexts):
-            (
-                model,
-                missing_keys,
-                unexpected_keys,
-                mismatched_keys,
-                offload_index,
-                error_msgs,
-            ) = cls._load_pretrained_model(
-                model,
-                state_dict,
-                loaded_state_dict_keys,  # XXX: rename?
-                resolved_archive_file,
-                pretrained_model_name_or_path,
-                ignore_mismatched_sizes=ignore_mismatched_sizes,
-                sharded_metadata=sharded_metadata,
-                _fast_init=_fast_init,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-                device_map=device_map,
-                offload_folder=offload_folder,
-                offload_state_dict=offload_state_dict,
-                dtype=torch_dtype,
-                hf_quantizer=hf_quantizer,
-                keep_in_fp32_modules=keep_in_fp32_modules,
-                gguf_path=gguf_path,
-                weights_only=weights_only,
-            )
+        (
+            model,
+            missing_keys,
+            unexpected_keys,
+            mismatched_keys,
+            offload_index,
+            error_msgs,
+        ) = cls._load_pretrained_model(
+            model,
+            state_dict,
+            loaded_state_dict_keys,  # XXX: rename?
+            resolved_archive_file,
+            pretrained_model_name_or_path,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            sharded_metadata=sharded_metadata,
+            _fast_init=_fast_init,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            device_map=device_map,
+            offload_folder=offload_folder,
+            offload_state_dict=offload_state_dict,
+            dtype=torch_dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            gguf_path=gguf_path,
+            weights_only=weights_only,
+        )
 
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
diff --git a/tests/tp/test_tp.py b/tests/tp/test_tp.py
index 2139a648867b..3df57c5f955c 100644
--- a/tests/tp/test_tp.py
+++ b/tests/tp/test_tp.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 import os
+import subprocess
+import tempfile
+import textwrap
 
 from transformers import is_torch_available
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -30,6 +33,22 @@
 
 
 class TestTensorParallel(TestCasePlus):
+    def torchrun(self, script: str):
+        """Run `script` with the `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
+        with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
+            tmp.write(script)
+            tmp.flush()
+            tmp.seek(0)
+            cmd = (
+                f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
+            ).split()
+
+            # Note that the subprocess will be waited for here, and an error will be raised if it is not successful
+            try:
+                _ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True)
+            except subprocess.CalledProcessError as e:
+                raise Exception(f"The following error was captured: {e.stderr}")
+
     @require_torch_multi_gpu
     def test_tp(self):
         distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
@@ -43,6 +62,42 @@ def test_tp(self):
         execute_subprocess_async(cmd, env=self.get_env())
         # successful return here == success - any errors would have caused an error in the sub-call
 
+    @require_torch_multi_gpu
+    def test_loading_memory_consumption(self):
+        script_to_run = textwrap.dedent(
+            """
+            import torch
+            import os
+            from transformers import AutoModelForCausalLM
+
+            model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+            rank = int(os.environ["RANK"])
+            world_size = int(os.environ["WORLD_SIZE"])
+            device = torch.device(f"cuda:{rank}")
+            torch.distributed.init_process_group("nccl", device_id=device)
+
+            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
+            torch.distributed.barrier()
+
+            # The expected full model memory footprint in GiB (8B parameters in float16)
+            expected_model_memory = 16
+            overhead_factor = 1.2
+
+            # Assert we did not use more than the full model expected memory (with some overhead)
+            if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
+                raise ValueError("Loading the model used more memory than the full model size")
size") + + # Assert we correctly handled the sharding between devices + if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor: + raise ValueError("Each model shard is larger than what is expected.") + + torch.distributed.barrier() + torch.distributed.destroy_process_group() + """ + ) + self.torchrun(script_to_run) + if __name__ == "__main__": # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs: