92 changes: 49 additions & 43 deletions src/transformers/modeling_utils.py
@@ -3443,6 +3443,29 @@ def from_pretrained(
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")

if tp_plan is not None and device_map is not None:
raise ValueError(
"`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
)

# We need to correctly dispatch the model on the current process device. The easiest way to do this is to use a simple
# `device_map` pointing to the correct device. If we don't, torch will use the default device (index 0) for all
# child processes at parallelization time, resulting in excessive memory usage on device 0 and OOMs.
# Temporarily setting the default device to the current process rank instead results in the following error:
# `torch.distributed.DistBackendError: Attempt to perform collective on tensor not on device passed to init_process_group`
tp_device = None
if tp_plan is not None:
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

# Detect the accelerator on the machine; if none is available, this falls back to CPU.
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
# Get the device index for this rank, assuming an equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
# This is the easiest way to dispatch to the current process device
device_map = tp_device

if is_fsdp_enabled():
low_cpu_mem_usage = True

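The block above is the heart of the change: every rank derives its own device from its distributed rank and passes it on as `device_map`, so weights land on the right accelerator at load time instead of piling up on device 0. A minimal standalone sketch of the same selection logic, with `current_process_device` as an illustrative name (not part of the patch):

import torch

def current_process_device() -> torch.device:
    # Detect the accelerator backend; torch falls back to "cpu" when none is present.
    device_type = torch._C._get_accelerator().type
    device_module = torch.get_device_module(device_type)
    # Index devices by global rank, assuming every host exposes the same device count.
    return torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())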
@@ -4090,7 +4113,6 @@ def from_pretrained(

# Instantiate model.
init_contexts = [no_init_weights(_enable=_fast_init)]
tp_device = None

if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed
@@ -4106,16 +4128,6 @@
f"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
)
init_contexts.append(init_empty_weights())
elif tp_plan is not None:
if not torch.distributed.is_initialized():
raise ValueError("Tensor Parallel requires torch.distributed to be initialized first.")

# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
# Get device with index assuming equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
init_contexts.append(tp_device)

if is_deepspeed_zero3_enabled() and is_quantized:
init_contexts.append(set_quantized_state())
@@ -4249,38 +4261,32 @@
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

load_contexts = []
# Make sure we load onto targeted device
if tp_device is not None:
load_contexts.append(tp_device)

with ContextManagers(load_contexts):
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
resolved_archive_file,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
sharded_metadata=sharded_metadata,
_fast_init=_fast_init,
low_cpu_mem_usage=low_cpu_mem_usage,
device_map=device_map,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
gguf_path=gguf_path,
weights_only=weights_only,
)

# make sure token embedding weights are still tied if needed
model.tie_weights()
55 changes: 55 additions & 0 deletions tests/tp/test_tp.py
@@ -13,6 +13,9 @@
# limitations under the License.

import os
import subprocess
import tempfile
import textwrap

from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig
@@ -30,6 +33,22 @@


class TestTensorParallel(TestCasePlus):
def torchrun(self, script: str):
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necesary."""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script)
tmp.flush()
tmp.seek(0)
cmd = (
f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()

# Note that the subprocess will be waited for here, and raise an error if not successful
try:
_ = subprocess.run(cmd, capture_output=True, env=self.get_env(), text=True, check=True)
except subprocess.CalledProcessError as e:
raise Exception(f"The following error was captured: {e.stderr}")
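For illustration, a hedged usage sketch of this helper from inside the test class; the script body below is hypothetical and not part of the PR. Each rank joins the process group and reports in, and a non-zero exit on any rank surfaces through the `CalledProcessError` path above:

script = textwrap.dedent(
    """
    import os
    import torch

    rank = int(os.environ["RANK"])
    torch.distributed.init_process_group("nccl", device_id=torch.device(f"cuda:{rank}"))
    print(f"rank {rank}/{os.environ['WORLD_SIZE']} ready")
    torch.distributed.destroy_process_group()
    """
)
self.torchrun(script)  # raises if any rank fails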

@require_torch_multi_gpu
def test_tp(self):
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
Expand All @@ -43,6 +62,42 @@ def test_tp(self):
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call

@require_torch_multi_gpu
def test_loading_memory_consumption(self):
script_to_run = textwrap.dedent(
"""
import torch
import os
from transformers import AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
torch.distributed.barrier()

# The expected full model memory footprint, in GiB (~8B params at 2 bytes each in float16)
expected_model_memory = 16
overhead_factor = 1.2

# Assert we did not use more than the full model expected memory (with some overhead)
if not torch.cuda.max_memory_allocated(device) / 1024**3 < expected_model_memory * overhead_factor:
raise ValueError("Loading the model used more than the full model size")

# Assert we correctly handled the sharding between devices
if not torch.cuda.memory_allocated(device) / 1024**3 < (expected_model_memory / world_size) * overhead_factor:
raise ValueError("Each model shard is larger than what is expected.")

torch.distributed.barrier()
torch.distributed.destroy_process_group()
"""
)
self.torchrun(script_to_run)
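The thresholds in the script follow from simple arithmetic, sketched here under the assumption that Meta-Llama-3-8B has roughly 8e9 parameters stored at 2 bytes each in float16:

# Back-of-the-envelope check for the thresholds used in the test (illustrative only).
n_params = 8.03e9                                      # approximate Llama-3-8B parameter count
bytes_per_param = 2                                    # float16
full_model_gib = n_params * bytes_per_param / 1024**3  # ~15 GiB, rounded up to 16 in the test
world_size = 4                                         # example: 4 GPUs per node
per_shard_budget_gib = (16 / world_size) * 1.2         # each rank's allowance with 20% overhead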


if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs: