
Commit 233ca08

refactor pt loading
1 parent 0a0c74e commit 233ca08

21 files changed: +909 -703 lines
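The commit moves PyTorch-checkpoint ("torch" model format) handling out of the per-shard weight loaders: linear weights are now created in the checkpoint's own layout and transposed once after loading (process_weight_transpose / process_weights_after_loading), instead of carrying a weight_need_transpose attribute and converting inside every weight_loader call. The layout difference the diffs below compensate for is that PyTorch nn.Linear stores its weight as [out_features, in_features], while Paddle's matmul-based linear layers expect [in_features, out_features]. A minimal, self-contained sketch of that relationship (tensor sizes here are illustrative only, not taken from the diff):

import paddle

# Hypothetical layer with in_features=4, out_features=8.
# A torch-format checkpoint stores the weight as [out_features, in_features] ...
torch_layout = paddle.randn([8, 4])

# ... while Paddle's linear layers expect [in_features, out_features],
# so the tensor must be transposed exactly once before inference.
paddle_layout = torch_layout.transpose([1, 0])
assert paddle_layout.shape == [4, 8]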

fastdeploy/model_executor/layers/linear.py

Lines changed: 43 additions & 23 deletions
@@ -25,6 +25,7 @@
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.utils import (
     default_weight_loader,
+    process_weight_transpose,
     set_weight_attrs,
     slice_fn,
 )
@@ -43,24 +44,36 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
         - weight_loader: a callable or method responsible for loading the weight data
         """
+        self.model_format = extra_weight_attrs.get("model_format")
+        self.weight_shape = (
+            layer.weight_shape[::-1] if extra_weight_attrs.get("model_format") == "torch" else layer.weight_shape
+        )
+
         layer.weight = layer.create_parameter(
-            shape=layer.weight_shape,
+            shape=self.weight_shape,
             dtype=layer.weight_dtype,
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
         split_axis = extra_weight_attrs.get("split_axis")
         if hasattr(layer, "nranks") and layer.nranks > 0:
             _set_var_distributed(layer.weight, split_axis=split_axis)
+
+        if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
+            extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
+
         set_weight_attrs(
             layer.weight,
             {
                 **extra_weight_attrs,
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
-                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
             },
         )
 
+    def process_weights_after_loading(self, layer):
+        if self.model_format == "torch":
+            process_weight_transpose(layer, "weight")
+
     def process_loaded_weights(self, layer, weights) -> None:
         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
         if layer.weight.dtype != weights.dtype:
@@ -165,7 +178,7 @@ def __init__(
         if self.with_bias:
             self.bias = self.create_parameter(
                 shape=[self.output_size],
-                dtype=self._dtype,
+                dtype=self.weight_dtype,
                 is_bias=True,
             )
             setattr(
@@ -262,6 +275,7 @@ def __init__(
         skip_quant: bool = False,
         weight_dtype: str = "",
         weight_key: str = "",
+        model_format: Optional[str] = None,
     ):
         """
         Initializes a replicated linear layer.
@@ -296,7 +310,7 @@ def __init__(
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
-            model_format=fd_config.model_config.model_format,
+            model_format=fd_config.model_config.model_format if model_format is None else model_format,
         )
 
 
@@ -344,7 +358,6 @@ def __init__(
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
-        loaded_weight = get_tensor(loaded_weight)
 
         if weight_need_transpose:
             loaded_weight = loaded_weight.transpose([1, 0])
@@ -393,7 +406,7 @@ def __init__(
         with_bias: bool = False,
         add_bias: bool = False,
         skip_quant: bool = False,
-        weight_dtype="",
+        weight_dtype: str = "",
     ):
         """
         Initializes a linear layer and provides additional parameters required for inference and quantization.
@@ -500,7 +513,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
         output_size = param.shape[shard_dim]
         if loaded_shard_id is None:
            if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # Avoid redundant transpose of fused weights when weight_loader is called iteratively
             param.weight_need_transpose = False
@@ -519,7 +531,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
             # split gate up
             assert loaded_shard_id in ["gate", "up"]
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
@@ -532,7 +543,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
                 shard_offset = self.local_rank * block_size
                 shard_size = (self.local_rank + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
-                loaded_weight = get_tensor(loaded_weight)
             if not param._is_initialized():
                 param.initialize()
             param_shard_size = output_size // 2
@@ -589,7 +599,19 @@ class QKVParallelLinear(ColumnParallelLinear):
     QKVParallelLinear Layer.
     """
 
-    def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
+    def __init__(
+        self,
+        fd_config,
+        prefix,
+        with_bias=False,
+        add_bias=True,
+        num_heads: Optional[int] = None,
+        kv_num_heads: Optional[int] = None,
+        hidden_size: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        skip_quant: bool = False,
+        weight_dtype: str = "",
+    ):
         """
         Initialize the QKV Linear layer with given parameters.
 
@@ -599,11 +621,15 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
                 Can be arbitrarily named.
             with_bias (bool): Whether to include bias or not. Defaults to False.
             add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
+            num_heads (Optional[int]): Number of attention heads in the model.
+            kv_num_heads (Optional[int]): Number of key/value heads, used for multi-query or grouped-query attention.
+            hidden_size (Optional[int]): Total hidden layer dimension, typically the embedding size.
+            head_dim (Optional[int]): Size of each attention head, usually computed as hidden_size divided by num_heads.
         """
-        self.num_heads = fd_config.model_config.num_attention_heads
-        self.kv_num_heads = fd_config.model_config.num_key_value_heads
-        self.hidden_size = fd_config.model_config.hidden_size
-        self.head_dim = fd_config.model_config.head_dim
+        self.num_heads = fd_config.model_config.num_attention_heads if num_heads is None else num_heads
+        self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
+        self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
+        self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
         self.num_heads_per_rank = divide(self.num_heads, self.nranks)
@@ -623,6 +649,8 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
             output_size=output_size,
             with_bias=with_bias,
             add_bias=add_bias,
+            skip_quant=skip_quant,
+            weight_dtype=weight_dtype,
         )
 
     def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
@@ -641,7 +669,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if loaded_shard_id is None:
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # Avoid redundant transpose of fused weights when weight_loader is called iteratively
             param.weight_need_transpose = False
@@ -661,7 +688,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
             # split q k v
             assert loaded_shard_id in ["q", "k", "v"]
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
@@ -671,8 +697,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
                 shard_size = block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
 
-            loaded_weight = get_tensor(loaded_weight)
-
             if not param._is_initialized():
                 param.initialize()
 
@@ -798,7 +822,7 @@ def __init__(
         add_bias: bool = False,
         reduce_results: bool = True,
         skip_quant: bool = False,
-        weight_dtype="",
+        weight_dtype: str = "",
     ):
         """
         Initialize a linear layer with additional parameters for inference and quantization.
@@ -847,10 +871,6 @@ def __init__(
             ),
             model_format=fd_config.model_config.model_format,
         )
-        if self.nranks > 0:
-            if self.with_bias:
-                # col parallel
-                _set_var_distributed(self.bias, split_axis=0)
 
         self.reduce_results = reduce_results
 
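With the weight kept in torch layout until process_weights_after_loading runs, the output dimension of a torch-format parameter sits on axis 0 instead of axis 1, which is why create_weights above negates the output_dim attribute before calling set_weight_attrs. A small sketch of why the same logical output-dim shard lands on different axes in the two layouts (shapes and rank values are made up for illustration, not FastDeploy code):

import paddle

out_features, in_features, nranks, rank = 8, 4, 2, 0
block = out_features // nranks

paddle_layout = paddle.arange(in_features * out_features, dtype="float32").reshape([in_features, out_features])
torch_layout = paddle_layout.transpose([1, 0])  # [out_features, in_features]

# Paddle layout: the output dimension is axis 1, so this rank slices columns.
shard_paddle = paddle_layout[:, rank * block : (rank + 1) * block]

# Torch layout: the same output-dim shard is a row slice on axis 0, hence the
# flipped output_dim attribute for torch-format checkpoints.
shard_torch = torch_layout[rank * block : (rank + 1) * block, :]

assert paddle.allclose(shard_paddle, shard_torch.transpose([1, 0]))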
fastdeploy/model_executor/layers/lm_head.py

Lines changed: 46 additions & 10 deletions
@@ -28,6 +28,7 @@
 )
 from fastdeploy.model_executor.utils import (
     default_weight_loader,
+    free_tensor,
     set_weight_attrs,
     temporary_dtype,
 )
@@ -69,6 +70,7 @@ def __init__(
             self.bias_key: Optional[str] = prefix + ".bias"
         else:
             self.bias_key: Optional[str] = None
+        self.embedding_dim = embedding_dim
         self.tp_group = fd_config.parallel_config.tp_group
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
@@ -77,34 +79,51 @@ def __init__(
 
         if num_embeddings % self.nranks != 0:
             num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
+        self.num_embeddings = num_embeddings
+        self.model_format = fd_config.model_config.model_format
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
         self.dtype = "float32" if fd_config.model_config.lm_head_fp32 else dtype
 
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
+        self.need_gather = True
 
         with temporary_dtype(self.dtype):
-            if self.column_cut:
-                need_gather = True
+            if self.fd_config.load_config.load_choices == "default_v1" and self.model_format == "torch":
+                self.linear = RowParallelLinear(
+                    num_embeddings,
+                    embedding_dim,
+                    mp_group=self.tp_group,
+                    weight_attr=None,
+                    has_bias=True if self.bias_key is not None else False,
+                    input_is_parallel=False,
+                    fuse_matmul_bias=False,
+                )
+                set_weight_attrs(
+                    self.linear.weight,
+                    {
+                        "weight_loader": default_weight_loader(self.fd_config),
+                    },
+                )
+                set_weight_attrs(self.linear.weight, {"output_dim": False})
+            elif self.column_cut:
                 self.linear = ColumnParallelLinear(
                     embedding_dim,
                     num_embeddings,
                     mp_group=self.tp_group,
                     weight_attr=None,
                     has_bias=True if self.bias_key is not None else False,
-                    gather_output=need_gather,
+                    gather_output=self.need_gather,
                     fuse_matmul_bias=False,
                 )
                 set_weight_attrs(
                     self.linear.weight,
                     {
                         "weight_loader": default_weight_loader(self.fd_config),
-                        "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                     },
                 )
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": True})
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
             else:
                 self.linear = RowParallelLinear(
                     embedding_dim,
@@ -119,12 +138,29 @@ def __init__(
                     self.linear.weight,
                     {
                         "weight_loader": default_weight_loader(self.fd_config),
-                        "weight_need_transpose": self.fd_config.model_config.model_format == "torch",
                     },
                 )
-
-                if self.nranks > 1:
-                    set_weight_attrs(self.linear.weight, {"output_dim": False})
+                set_weight_attrs(self.linear.weight, {"output_dim": False})
+
+    def process_weights_after_loading(self):
+        if self.model_format != "torch":
+            return
+        if not self.linear.weight._is_initialized():
+            self.linear.weight.initialize()
+        weight_transpose = self.linear.weight.transpose([1, 0])
+        with temporary_dtype(self.dtype):
+            linear = fleet.meta_parallel.ColumnParallelLinear(
+                self.embedding_dim,
+                self.num_embeddings,
+                mp_group=self.tp_group,
+                weight_attr=None,
+                has_bias=True if self.bias_key is not None else False,
+                gather_output=self.need_gather,
+                fuse_matmul_bias=False,
+            )
+            linear.weight.set_value(weight_transpose)
+            free_tensor(self.linear.weight)
+            self.linear = linear
 
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """

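For the lm_head path, the new process_weights_after_loading keeps the checkpoint-layout weight only long enough to transpose it, copies it into a freshly built ColumnParallelLinear, and releases the original storage with free_tensor. A self-contained sketch of that transpose-and-swap pattern with plain Paddle layers (the layer types and sizes here are stand-ins for illustration, not the fleet.meta_parallel classes used in the diff):

import paddle.nn as nn

embedding_dim, num_embeddings = 16, 32

# Stand-in for the temporary row-parallel holder: weight stored in the
# checkpoint's [num_embeddings, embedding_dim] layout.
holder = nn.Linear(num_embeddings, embedding_dim)

# Stand-in for the layer used at inference time: [embedding_dim, num_embeddings].
final = nn.Linear(embedding_dim, num_embeddings)

# Transpose once, copy the data across, then drop the holder so its storage
# can be reclaimed (the diff calls free_tensor on the old parameter instead).
final.weight.set_value(holder.weight.transpose([1, 0]))
del holder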