use weights iterator while loading #2886

Merged: 9 commits, Jan 20, 2025
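At a high level, the PR stops materializing each checkpoint shard as a full state dict and instead streams (name, tensor) pairs straight into model.load_weights. A minimal sketch of the two patterns, assuming safetensors shards; the helper names below are illustrative and not part of the PR:

# Sketch only: eager shard loading (roughly the old behaviour) vs. the lazy
# per-tensor iteration this PR adopts.
from safetensors.torch import load_file, safe_open


def eager_load(path: str):
    """Materialize every tensor in the shard at once."""
    return load_file(path)  # dict: name -> tensor, all resident in host memory


def lazy_iter(path: str, prefix: str = None):
    """Yield (name, tensor) pairs one at a time, so each tensor can be copied
    into the model and released before the next one is read."""
    with safe_open(path, framework='pt') as f:
        for name in f.keys():
            tensor = f.get_tensor(name)
            yield (name if prefix is None else f'{prefix}{name}'), tensor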
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/models/deepseek_v2.py
@@ -332,7 +332,7 @@ def forward(self, hidden_states: torch.Tensor):
if self.topk_method == 'greedy':
topk_weight, topk_idx = self.softmax_topk(router_logits)
elif self.topk_method == 'group_limited_greedy':
scores = self._compute_scores(router_logits)
scores = router_logits
grouped_logits = scores.unflatten(-1, (self.n_group, -1))
group_scores = (grouped_logits.max(-1).values)
group_idx = torch.topk(group_scores,
@@ -758,6 +758,7 @@ def _load_weight_attention(self, name: str, loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter],
update_pe_mapping: List):
"""load weight attention."""
device = next(iter(params_dict.values())).device

def __update_pe(weight, head_dim: int, pe_dim_offset: int):
# (num_heads, q_head_dim, input_dim)
@@ -828,6 +829,7 @@ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
if name.endswith('.scale'):
weight = loaded_weight
else:
loaded_weight = loaded_weight.to(device)
weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)
param = params_dict[name]
load_weight(param, weight)
@@ -839,6 +841,7 @@ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
if quantization_config is not None:
quant_method = quantization_config.get('quant_method')

loaded_weight = loaded_weight.to(device)
if quant_method == 'fp8':
# update blocked fp8 weight
__load_kcvc_blocked_fp8(name, loaded_weight)
12 changes: 4 additions & 8 deletions lmdeploy/pytorch/models/internvl.py
@@ -536,9 +536,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""

lang_prefix = 'language_model.'
lang_prefix_length = len(lang_prefix)
new_weights = dict()
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if name.startswith(lang_prefix):
new_key = name[lang_prefix_length:]
new_weights[new_key] = loaded_weight
continue

if 'qkv' in name:
@@ -551,14 +555,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
param = params_dict[name]
load_weight(param, loaded_weight)

lang_prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
if not key.startswith(lang_prefix):
continue
new_key = key[lang_prefix_length:]
new_weights[new_key] = val

self.language_model.load_weights(new_weights.items())

def get_input_processor(self) -> BaseModelInputProcessor:
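One reason the language-model weights are now collected inside the single loop (here and in llava.py below): once weights is a one-shot iterator rather than a dict, a second for key, val in weights pass would see nothing. A hedged sketch of the pattern, using a hypothetical helper name:

def split_weights(weights, lang_prefix: str = 'language_model.'):
    """Route language-model tensors while streaming over `weights` exactly once."""
    lang_weights = {}
    vis_weights = {}
    for name, tensor in weights:
        if name.startswith(lang_prefix):
            # strip the prefix so the sub-model sees its own parameter names
            lang_weights[name[len(lang_prefix):]] = tensor
        else:
            vis_weights[name] = tensor
    return vis_weights, lang_weights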
13 changes: 4 additions & 9 deletions lmdeploy/pytorch/models/llava.py
@@ -587,9 +587,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

# vis model
lang_prefix = 'language_model.'
prefix_length = len(lang_prefix)
new_weights = dict()
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if name.startswith(lang_prefix):
new_key = name[prefix_length:]
new_weights[new_key] = loaded_weight
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
@@ -603,15 +607,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
param = params_dict[name]
load_weight(param, loaded_weight)

# language model
prefix_length = len(lang_prefix)
new_weights = dict()
for key, val in weights:
if not key.startswith(lang_prefix):
continue
new_key = key[prefix_length:]
new_weights[new_key] = val

self.language_model.load_weights(new_weights.items())

def get_input_processor(self) -> BaseModelInputProcessor:
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/nn/linear.py
@@ -125,6 +125,7 @@ def weight_loader_A(self, param: nn.Parameter, loaded_weight: torch.Tensor,

if self.is_tp and not self.colwise:
world_size, rank = get_world_rank()
loaded_weight = loaded_weight.to(param_r.device)
loaded_weight = loaded_weight.chunk(world_size, dim=1)[rank]

param_r.copy_(loaded_weight)
@@ -226,6 +227,7 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
world_size: int):
"""weight loader for rowwise linear."""
if loaded_weight.dim() == 2:
loaded_weight = loaded_weight.to(param.device)
weight = loaded_weight.chunk(world_size, 1)[rank]
return default_weight_loader(param, weight)
else:
@@ -887,6 +889,7 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
if loaded_weight.dim() == 2 and param.dtype in (torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2):
loaded_weight = loaded_weight.to(param.device)
weight = loaded_weight.chunk(world_size, 1)[rank]
return default_weight_loader(param, weight)
elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1:
@@ -1188,6 +1191,7 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
world_size: int):
"""weight loader for rowwise linear."""
if loaded_weight.dim() == 2:
loaded_weight = loaded_weight.to(param.device)
weight = _chunk_align(loaded_weight, world_size, 1,
self.tp_align_size)[rank]
return default_weight_loader(param, weight)
5 changes: 4 additions & 1 deletion lmdeploy/pytorch/nn/moe.py
@@ -95,7 +95,9 @@ def weight_loader_tp(self, param: torch.nn.Parameter,
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
param_data = param.data[expert_id]
weight = loaded_weight.chunk(world_size, dim=1)[rank]
# weight is not contiguous; chunking and copying on CPU is slow
weight = loaded_weight.to(param_data.device)
weight = weight.chunk(world_size, dim=1)[rank]
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
param_data.copy_(weight)
@@ -403,6 +405,7 @@ def weight_loader_scale_tp(self, param: torch.nn.Parameter,
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
param_data = param.data[expert_id]
loaded_weight = loaded_weight.to(param_data.device)
weight = loaded_weight.chunk(world_size, dim=1)[rank]
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
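The comment in weight_loader_tp hints at why the tensor is moved before chunking: slicing along dim=1 produces non-contiguous views, and copying those on the CPU is slow, while moving the whole tensor to the target device first keeps the strided copy on-device. A rough illustration with made-up shapes:

import torch

loaded_weight = torch.randn(4096, 11008)   # stand-in for a down-proj expert weight
cpu_shard = loaded_weight.chunk(2, dim=1)[0]
print(cpu_shard.is_contiguous())           # False: strided view over host memory

if torch.cuda.is_available():
    device = torch.device('cuda')
    param = torch.empty(4096, 5504, device=device)
    gpu_shard = loaded_weight.to(device).chunk(2, dim=1)[0]
    param.copy_(gpu_shard)                 # strided copy now runs on the GPU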
62 changes: 41 additions & 21 deletions lmdeploy/pytorch/weight_loader/model_weight_loader.py
@@ -3,7 +3,8 @@
import os.path as osp

import torch
from transformers.modeling_utils import load_state_dict
from safetensors.torch import safe_open
from tqdm.auto import tqdm
from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_NAME)

@@ -90,6 +91,28 @@ def _get_weight_path(model_path: str, weight_type: str):
return weight_path, weight_name


def _get_safetensors_weights_iterator(file: str, prefix: str):
"""get safeternsors weights iterator."""
with safe_open(file, framework='pt') as f:
for name in f.keys():
param = f.get_tensor(name)
if prefix is not None:
name = f'{prefix}{name}'
yield name, param


def _get_pt_weights_iterator(file: str, prefix: str):
"""get pt weights iterator."""
state = torch.load(file, weights_only=True, map_location='cpu')
if prefix is None:
yield from state.items()
else:
for k, v in state.items():
yield f'{prefix}{k}', v
del state
torch.cuda.empty_cache()
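A possible way to consume the helper above (illustrative only; the shard filename is hypothetical):

if __name__ == '__main__':
    shard = 'model-00001-of-00004.safetensors'  # hypothetical shard name
    for name, tensor in _get_safetensors_weights_iterator(shard, prefix=None):
        print(name, tuple(tensor.shape), tensor.dtype)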


class ModelWeightLoader:
"""model weight loader for sharded weights."""

@@ -115,13 +138,14 @@ def _get_shard_paths(model_path: str, is_sharded: bool, weight_type: str):
path, _ = _get_weight_path(model_path, weight_type)
return (path, )

def _load_shard(self, path: str):
"""load shards."""
state_dict = load_state_dict(path)
if self._prefix is not None:
state_dict = dict(
(f'{self._prefix}{k}', v) for k, v in state_dict.items())
return state_dict
def _get_weights_iterator(self, path: str):
"""get weights iterator."""
if self._weight_type == 'safetensors':
weights_iterator = _get_safetensors_weights_iterator(
path, self._prefix)
else:
weights_iterator = _get_pt_weights_iterator(path, self._prefix)
return weights_iterator

def load_model_weights(
self,
@@ -131,19 +155,15 @@ def load_model_weights(
"""load model weights implementation."""
assert hasattr(model, 'load_weights')
paths = self._shard_paths
world_size, rank = get_world_rank()
for path in paths:

# log
file_name = osp.split(path)[1]
msg = f'loading weights - "{file_name}"'
if world_size > 1:
msg = f'rank[{rank}] {msg}'
logger.info(msg)

# process
state_dict = self._load_shard(path)
model.load_weights(state_dict.items())
_, rank = get_world_rank()
disable_tqdm = rank != 0

paths = sorted(paths)
for path in tqdm(paths,
desc='Loading weights from safetensors',
disable=disable_tqdm):
weights_iterator = self._get_weights_iterator(path)
model.load_weights(weights_iterator)
if device is not None:
device = model.to(device)
