use weights iterator while loading (#2886)
* use weights iterator

* optimize load tp colwise

* optimize w8 and adapter

* fix vl

* update runtime

* fix dsv2 gate

* sort before load

* remove tqdm

---------

Co-authored-by: grimoire <[email protected]>
RunningLeon and grimoire authored Jan 20, 2025
1 parent 832bfc4 commit 3f8b079
Showing 6 changed files with 61 additions and 40 deletions.
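
At its core, the commit replaces transformers' load_state_dict (which materializes a whole shard as a dict in host memory) with per-shard generators that yield one (name, tensor) pair at a time, so weights can be moved to their target device and consumed as they stream in. Below is a minimal sketch of the idea, not the loader's exact code:

from safetensors.torch import safe_open

def iter_safetensors(path: str):
    """Yield (name, tensor) pairs lazily instead of building a full state dict."""
    with safe_open(path, framework='pt') as f:
        for name in f.keys():
            yield name, f.get_tensor(name)

# Old flow, per shard: materialize everything, then hand it over at once.
#   state_dict = load_state_dict(path)
#   model.load_weights(state_dict.items())
# New flow, per shard: tensors stream through load_weights one at a time.
#   model.load_weights(iter_safetensors(path))
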
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/deepseek_v2.py
@@ -332,7 +332,7 @@ def forward(self, hidden_states: torch.Tensor):
if self.topk_method == 'greedy':
topk_weight, topk_idx = self.softmax_topk(router_logits)
elif self.topk_method == 'group_limited_greedy':
scores = self._compute_scores(router_logits)
scores = router_logits
grouped_logits = scores.unflatten(-1, (self.n_group, -1))
group_scores = (grouped_logits.max(-1).values)
group_idx = torch.topk(group_scores,
@@ -758,6 +758,7 @@ def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter],
update_pe_mapping: List):
"""load weight attention."""
device = next(iter(params_dict.values())).device

def __update_pe(weight, head_dim: int, pe_dim_offset: int):
# (num_heads, q_head_dim, input_dim)
@@ -828,6 +829,7 @@ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
if name.endswith('.scale'):
weight = loaded_weight
else:
loaded_weight = loaded_weight.to(device)
weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)
param = params_dict[name]
load_weight(param, weight)
@@ -839,6 +841,7 @@ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
if quantization_config is not None:
quant_method = quantization_config.get('quant_method')

loaded_weight = loaded_weight.to(device)
if quant_method == 'fp8':
# update blocked fp8 weight
__load_kcvc_blocked_fp8(name, loaded_weight)
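
In _load_weight_attention, the target device is now taken from an already-created parameter, and each incoming tensor is moved there before the positional-encoding permutation in __update_pe, so the reshapes and copies run on the device rather than on CPU tensors. A rough sketch of that pattern (the parameter name and shapes below are made up):

import torch
import torch.nn as nn

dev = 'cuda' if torch.cuda.is_available() else 'cpu'
# Stand-in for the module's real parameter dict.
params_dict = {'self_attn.kv_b_proj.weight': nn.Parameter(torch.empty(64, 32, device=dev))}

# Derive the target device from any existing parameter ...
device = next(iter(params_dict.values())).device

def load_attn_weight(name: str, loaded_weight: torch.Tensor) -> None:
    # ... and move the (usually CPU, mmap-backed) tensor there before
    # any permute/reshape, so the transform and copy happen on-device.
    loaded_weight = loaded_weight.to(device)
    params_dict[name].data.copy_(loaded_weight)

load_attn_weight('self_attn.kv_b_proj.weight', torch.randn(64, 32))
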
12 changes: 4 additions & 8 deletions lmdeploy/pytorch/models/internvl.py
@@ -536,9 +536,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""

lang_prefix = 'language_model.'
lang_prefix_length = len(lang_prefix)
new_weights = dict()
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if name.startswith(lang_prefix):
new_key = name[lang_prefix_length:]
new_weights[new_key] = loaded_weight
continue

if 'qkv' in name:
@@ -551,14 +555,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
param = params_dict[name]
load_weight(param, loaded_weight)

lang_prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
if not key.startswith(lang_prefix):
continue
new_key = key[lang_prefix_length:]
new_weights[new_key] = val

self.language_model.load_weights(new_weights.items())

def get_input_processor(self) -> BaseModelInputProcessor:
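
Because weights now typically arrives as a one-shot generator rather than a dict, it cannot be iterated a second time; the rewrite therefore stashes language_model.* tensors into new_weights during the same pass that loads the vision weights, instead of looping over weights again afterwards. llava.py below gets the identical restructuring. A self-contained sketch of the single-pass split (names and shapes are synthetic; the real loader copies vision weights immediately instead of buffering them):

from typing import Dict, Iterable, Tuple
import torch

def split_one_pass(weights: Iterable[Tuple[str, torch.Tensor]],
                   lang_prefix: str = 'language_model.'):
    """Route language-model tensors aside while handling the rest inline."""
    lang_weights: Dict[str, torch.Tensor] = {}
    other_weights: Dict[str, torch.Tensor] = {}
    for name, tensor in weights:              # the generator is consumed exactly once
        if name.startswith(lang_prefix):
            lang_weights[name[len(lang_prefix):]] = tensor
        else:
            other_weights[name] = tensor
    return other_weights, lang_weights

fake = iter([('vision_model.proj.weight', torch.zeros(2, 2)),
             ('language_model.embed_tokens.weight', torch.zeros(4, 2))])
vision, lang = split_one_pass(fake)
assert 'embed_tokens.weight' in lang and 'vision_model.proj.weight' in vision
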
13 changes: 4 additions & 9 deletions lmdeploy/pytorch/models/llava.py
@@ -587,9 +587,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

# vis model
lang_prefix = 'language_model.'
prefix_length = len(lang_prefix)
new_weights = dict()
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if name.startswith(lang_prefix):
new_key = name[prefix_length:]
new_weights[new_key] = loaded_weight
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
@@ -603,15 +607,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
param = params_dict[name]
load_weight(param, loaded_weight)

# language model
prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
if not key.startswith(lang_prefix):
continue
new_key = key[prefix_length:]
new_weights[new_key] = val

self.language_model.load_weights(new_weights.items())

def get_input_processor(self) -> BaseModelInputProcessor:
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/nn/linear.py
@@ -125,6 +125,7 @@ def weight_loader_A(self, param: nn.Parameter, loaded_weight: torch.Tensor,

if self.is_tp and not self.colwise:
world_size, rank = get_world_rank()
loaded_weight = loaded_weight.to(param_r.device)
loaded_weight = loaded_weight.chunk(world_size, dim=1)[rank]

param_r.copy_(loaded_weight)
@@ -226,6 +227,7 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
world_size: int):
"""weight loader for rowwise linear."""
if loaded_weight.dim() == 2:
loaded_weight = loaded_weight.to(param.device)
weight = loaded_weight.chunk(world_size, 1)[rank]
return default_weight_loader(param, weight)
else:
@@ -887,6 +889,7 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
if loaded_weight.dim() == 2 and param.dtype in (torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2):
loaded_weight = loaded_weight.to(param.device)
weight = loaded_weight.chunk(world_size, 1)[rank]
return default_weight_loader(param, weight)
elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1:
@@ -1188,6 +1191,7 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
world_size: int):
"""weight loader for rowwise linear."""
if loaded_weight.dim() == 2:
loaded_weight = loaded_weight.to(param.device)
weight = _chunk_align(loaded_weight, world_size, 1,
self.tp_align_size)[rank]
return default_weight_loader(param, weight)
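
Each row-wise tensor-parallel loader now moves the checkpoint tensor onto the parameter's device before chunking it along the input dimension, so the per-rank slice is taken and copied on the device instead of from a slow, often non-contiguous CPU/mmap tensor. A hedged sketch of the pattern with toy shapes, not the library's API:

import torch

def load_rowwise(param: torch.nn.Parameter, loaded_weight: torch.Tensor,
                 rank: int, world_size: int) -> None:
    """Copy this rank's slice of a row-parallel weight into the local param."""
    # Move to the target device first: chunking and copying a CPU/mmap tensor
    # host-side is the slow path this commit removes.
    loaded_weight = loaded_weight.to(param.device)
    shard = loaded_weight.chunk(world_size, dim=1)[rank]
    param.data.copy_(shard)

# Toy usage: full weight [8, 16], two ranks each hold an [8, 8] slice.
world_size, rank = 2, 0
full_weight = torch.randn(8, 16)
local_param = torch.nn.Parameter(torch.empty(8, 8))
load_rowwise(local_param, full_weight, rank, world_size)
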
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/nn/moe.py
@@ -95,7 +95,9 @@ def weight_loader_tp(self, param: torch.nn.Parameter,
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
param_data = param.data[expert_id]
weight = loaded_weight.chunk(world_size, dim=1)[rank]
# weight is not contiguous; chunking and copying on CPU is slow
weight = loaded_weight.to(param_data.device)
weight = weight.chunk(world_size, dim=1)[rank]
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
param_data.copy_(weight)
@@ -403,6 +405,7 @@ def weight_loader_scale_tp(self, param: torch.nn.Parameter,
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
param_data = param.data[expert_id]
loaded_weight = loaded_weight.to(param_data.device)
weight = loaded_weight.chunk(world_size, dim=1)[rank]
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
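
The fused-MoE loader applies the same device-first idea per expert: the 'down' shard is moved to the stacked parameter's device, chunked along dim 1, and copied into that expert's slice of the 3-D weight. A toy sketch (expert count and shapes are invented):

import torch

num_experts, hidden, ffn = 4, 8, 16
world_size, rank = 2, 1

# Stacked down-projections, one [hidden, ffn // world_size] slice per expert.
down = torch.nn.Parameter(torch.empty(num_experts, hidden, ffn // world_size))

def load_expert_down(expert_id: int, loaded_weight: torch.Tensor) -> None:
    param_data = down.data[expert_id]
    # Move to the target device first; chunk() yields a non-contiguous view,
    # and slicing plus copying it on CPU is the slow path being avoided.
    weight = loaded_weight.to(param_data.device)
    weight = weight.chunk(world_size, dim=1)[rank]
    param_data.copy_(weight)

load_expert_down(2, torch.randn(hidden, ffn))
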
62 changes: 41 additions & 21 deletions lmdeploy/pytorch/weight_loader/model_weight_loader.py
@@ -3,7 +3,8 @@
import os.path as osp

import torch
from transformers.modeling_utils import load_state_dict
from safetensors.torch import safe_open
from tqdm.auto import tqdm
from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_NAME)

@@ -90,6 +91,28 @@ def _get_weight_path(model_path: str, weight_type: str):
return weight_path, weight_name


def _get_safetensors_weights_iterator(file: str, prefix: str):
"""get safeternsors weights iterator."""
with safe_open(file, framework='pt') as f:
for name in f.keys():
param = f.get_tensor(name)
if prefix is not None:
name = f'{prefix}{name}'
yield name, param


def _get_pt_weights_iterator(file: str, prefix: str):
"""get pt weights iterator."""
state = torch.load(file, weights_only=True, map_location='cpu')
if prefix is None:
yield from state.items()
else:
for k, v in state.items():
yield f'{prefix}{k}', v
del state
torch.cuda.empty_cache()
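
Of the two helpers, only the safetensors path is truly lazy: safe_open maps the shard and materializes tensors on request, while the legacy .bin path still has to torch.load the whole shard and only frees it (and empties the CUDA cache) once the generator is exhausted. A tiny self-contained sketch of the lazy safetensors access, using a throwaway file:

import torch
from safetensors.torch import save_file, safe_open

# Write a throwaway shard so the sketch runs stand-alone.
save_file({'layer.weight': torch.randn(4, 4)}, 'tiny_shard.safetensors')

# safe_open maps the file; each tensor is only materialized when requested,
# which keeps peak host memory near a single tensor at a time.
with safe_open('tiny_shard.safetensors', framework='pt') as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        print(name, tuple(tensor.shape), tensor.dtype)
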


class ModelWeightLoader:
"""model weight loader for sharded weights."""

@@ -115,13 +138,14 @@ def _get_shard_paths(model_path: str, is_sharded: bool, weight_type: str):
path, _ = _get_weight_path(model_path, weight_type)
return (path, )

def _load_shard(self, path: str):
"""load shards."""
state_dict = load_state_dict(path)
if self._prefix is not None:
state_dict = dict(
(f'{self._prefix}{k}', v) for k, v in state_dict.items())
return state_dict
def _get_weights_iterator(self, path: str):
"""get weights iterator."""
if self._weight_type == 'safetensors':
weights_iterator = _get_safetensors_weights_iterator(
path, self._prefix)
else:
weights_iterator = _get_pt_weights_iterator(path, self._prefix)
return weights_iterator

def load_model_weights(
self,
@@ -131,19 +155,15 @@ def load_model_weights(
"""load model weights implementation."""
assert hasattr(model, 'load_weights')
paths = self._shard_paths
world_size, rank = get_world_rank()
for path in paths:

# log
file_name = osp.split(path)[1]
msg = f'loading weights - "{file_name}"'
if world_size > 1:
msg = f'rank[{rank}] {msg}'
logger.info(msg)

# process
state_dict = self._load_shard(path)
model.load_weights(state_dict.items())
_, rank = get_world_rank()
disable_tqdm = rank != 0

paths = sorted(paths)
for path in tqdm(paths,
desc='Loading weights from safetensors',
disable=disable_tqdm):
weights_iterator = self._get_weights_iterator(path)
model.load_weights(weights_iterator)
if device is not None:
device = model.to(device)

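
load_model_weights now sorts the shard paths so every rank visits them in the same order, shows a single tqdm bar on rank 0 only, and hands each shard's iterator straight to model.load_weights. A small sketch of the rank-gated progress bar (get_world_rank is stubbed out and the shard names are placeholders):

from tqdm.auto import tqdm

def get_world_rank():
    """Stub standing in for lmdeploy's distributed helper: (world_size, rank)."""
    return 1, 0

paths = sorted(['shard-b.safetensors', 'shard-a.safetensors'])  # deterministic order
_, rank = get_world_rank()
for path in tqdm(paths, desc='Loading weights from safetensors',
                 disable=rank != 0):   # non-zero ranks stay silent
    # In the real loader: model.load_weights(self._get_weights_iterator(path))
    pass
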
