51 changes: 28 additions & 23 deletions src/transformers/modeling_utils.py
@@ -4824,11 +4824,10 @@ def _load_pretrained_model(
# Warm up CUDA to load the weights much faster on the devices
if device_map is not None and hf_quantizer is None:
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, dtype)
caching_allocator_warmup(model_to_load, expanded_device_map)

error_msgs = []
mismatched_keys = []
has_multiple_shards = len(checkpoint_files) > 1
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
@@ -4865,7 +4864,7 @@ def _load_pretrained_model(
prefix if loading_base_model_from_task_state_dict else "",
)

if low_cpu_mem_usage and shard_file is not None:
if low_cpu_mem_usage:
# Skip it with fsdp on ranks other than 0
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
@@ -4893,10 +4892,8 @@ def _load_pretrained_model(
else:
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)

# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict
# force memory release if loading multiple shards
if has_multiple_shards:
gc.collect()

# Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0:
@@ -5789,11 +5786,24 @@ def expand_device_map(device_map, param_names):
return new_device_map


def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
the model, which is actually the loading speed botteneck.
Calling this function allows to cut the model loading time by a very large margin.

A few facts related to loading speed (taking into account the use of this function):
- Loading a model for the first time is usually slower than subsequent loads, because the OS will very likely
cache the different state dicts (if enough resources/RAM are available)
- Trying to force the OS to cache the files in advance (e.g. by accessing a small portion of them) is hard,
and generally not a good idea, as these are low-level OS optimizations that depend on resource usage anyway
- As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
The baseline, i.e. only loading the tensor shards onto the devices and adjusting dtype (i.e. copying them), is ~5s with full cache.
These numbers are reported for TP on 4 H100 GPUs.
- It is useless to pre-allocate more than the model size in this function (i.e. to use an `allocation_factor` > 1), as
cudaMalloc is no longer a bottleneck at all
- The loading-speed bottleneck is now almost entirely the tensor copies (i.e. changing the dtype) and moving the tensors to the devices.
However, we cannot really improve on those aspects, as the data needs to be moved/copied in the end.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
@@ -5808,31 +5818,26 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
else None
)

parameter_count = defaultdict(lambda: 0)
allocation_factor = 1
if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
allocation_factor = 2

total_byte_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
param = model.get_parameter_or_buffer(param_name)
param_size = int(math.prod(param.shape) * allocation_factor)
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)

if tp_plan_regex is not None:
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
param_size //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

parameter_count[device] += param_size
param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

dtype = dtype if dtype is not None else torch.float32
total_byte_count[device] += param_byte_count

# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, param_count in parameter_count.items():
max_memory_device = None
for device, byte_count in total_byte_count.items():
if device.type == "cuda":
max_memory_device = torch.cuda.mem_get_info(device.index)[0]
# allocate only if we have enough memory
if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype):
_ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
device_memory = torch.cuda.mem_get_info(device)[0]
# Allow up to 95% of max device memory
byte_count = min(byte_count, int(0.95 * device_memory))
# Allocate memory
_ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)


def get_disk_only_shard_files(device_map, weight_map):
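
For readers unfamiliar with the caching-allocator trick described in the docstring above, the sketch below shows the idea in isolation: request one large allocation sized from the summed parameter byte counts, let it go out of scope, and later per-tensor copies are then served from PyTorch's caching allocator instead of triggering many individual cudaMalloc calls. This is a minimal standalone sketch, not the transformers implementation; the `warmup_device` helper and the 2 GiB figure are illustrative assumptions.

import torch


def warmup_device(byte_count: int, device: torch.device) -> None:
    """Reserve roughly `byte_count` bytes on `device`, then let them return to the allocator's cache."""
    if device.type != "cuda":
        return
    free_bytes, _ = torch.cuda.mem_get_info(device)
    # Cap at ~95% of the currently free device memory, mirroring the check in the diff above
    byte_count = min(byte_count, int(0.95 * free_bytes))
    # float16 elements are 2 bytes each, so byte_count // 2 elements is roughly byte_count bytes
    _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
    # The tensor is dropped when the function returns; its memory stays in the caching allocator's pool


if torch.cuda.is_available():
    device = torch.device("cuda:0")
    warmup_device(2 * 1024**3, device)  # hypothetical figure: reserve ~2 GiB up front
    # Subsequent tensor materialization on this device hits the cache instead of new cudaMalloc calls
    weights = torch.empty(4096, 4096, dtype=torch.float16, device=device)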