
Commit c96a535

[Feature] support qwen3-embedding model load (#4202)
* support qwen3-embedding * fix ci bug * fix * fix ci bug * fix ci bug * fix
1 parent 9082f62 commit c96a535

5 files changed: 315 additions (+), 63 deletions (−)

fastdeploy/model_executor/layers/embeddings.py

Lines changed: 174 additions & 4 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+from dataclasses import dataclass
 from typing import Dict
 
 import numpy as np
@@ -22,9 +23,73 @@
 from paddle.distributed import fleet
 
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn
 
-from .utils import get_tensor
+from .utils import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    get_tensor,
+    pad_vocab_size,
+    vocab_range_from_global_vocab_size,
+)
+
+
+@dataclass
+class VocabParallelEmbeddingShardIndices:
+    """Indices for a shard of a vocab parallel embedding."""
+
+    padded_org_vocab_start_index: int
+    padded_org_vocab_end_index: int
+    padded_added_vocab_start_index: int
+    padded_added_vocab_end_index: int
+
+    org_vocab_start_index: int
+    org_vocab_end_index: int
+    added_vocab_start_index: int
+    added_vocab_end_index: int
+
+    @property
+    def num_org_elements(self) -> int:
+        return self.org_vocab_end_index - self.org_vocab_start_index
+
+    @property
+    def num_added_elements(self) -> int:
+        return self.added_vocab_end_index - self.added_vocab_start_index
+
+    @property
+    def num_org_elements_padded(self) -> int:
+        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
+
+    @property
+    def num_added_elements_padded(self) -> int:
+        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
+
+    @property
+    def num_org_vocab_padding(self) -> int:
+        return self.num_org_elements_padded - self.num_org_elements
+
+    @property
+    def num_added_vocab_padding(self) -> int:
+        return self.num_added_elements_padded - self.num_added_elements
+
+    @property
+    def num_elements_padded(self) -> int:
+        return self.num_org_elements_padded + self.num_added_elements_padded
+
+    def __post_init__(self):
+        # sanity checks
+        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
+        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
+
+        assert self.org_vocab_start_index <= self.org_vocab_end_index
+        assert self.added_vocab_start_index <= self.added_vocab_end_index
+
+        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
+        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
+        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
+        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
+
+        assert self.num_org_elements <= self.num_org_elements_padded
+        assert self.num_added_elements <= self.num_added_elements_padded
 
 
 class VocabParallelEmbedding(nn.Layer):
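
For intuition, here is a minimal sketch, with toy numbers that are not taken from the commit, of the shard this dataclass describes: an original vocab of 100 tokens, no added tokens, padded to 128 slots with the default padding_size of 64 and split across 2 tensor-parallel ranks, viewed from rank 1.

# Hypothetical toy shard: org vocab 100, padded to 128, tp_size=2, seen from rank 1.
# Rank 1 owns padded slots [64, 128); only [64, 100) hold real token rows.
shard = VocabParallelEmbeddingShardIndices(
    padded_org_vocab_start_index=64,
    padded_org_vocab_end_index=128,
    padded_added_vocab_start_index=100,
    padded_added_vocab_end_index=100,
    org_vocab_start_index=64,
    org_vocab_end_index=100,
    added_vocab_start_index=100,
    added_vocab_end_index=100,
)
assert shard.num_org_elements == 36         # real token rows on this rank
assert shard.num_org_elements_padded == 64  # rows allocated on this rank
assert shard.num_org_vocab_padding == 28    # zero-filled padding rows
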
@@ -39,6 +104,7 @@ def __init__(
         embedding_dim: int = 768,
         params_dtype: str = "bfloat16",
         prefix="",
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
     ) -> None:
         """
         Initialize the VocabParallelEmbedding layer for the model.
@@ -65,18 +131,40 @@ def __init__(
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype
+        self.padding_size = padding_size
+
+        self.org_vocab_size = num_embeddings
+        self.num_embeddings = num_embeddings
+        num_added_embeddings = num_embeddings - self.org_vocab_size
+
+        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
+        self.num_embeddings_padded = pad_vocab_size(
+            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+        )
+        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+        self.shard_indices = self._get_indices(
+            self.num_embeddings_padded,
+            self.org_vocab_size_padded,
+            self.num_embeddings,
+            self.org_vocab_size,
+            self.tensor_parallel_rank,
+            self.world_size,
+        )
+
+        if num_embeddings % self.world_size != 0:
+            self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)
 
         if not self.column_cut:
             self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-                num_embeddings,
+                self.num_embeddings_padded,
                 embedding_dim,
                 mp_group=self.tp_group,
                 weight_attr=paddle.ParamAttr(
                     initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                 ),
             )
             if self.world_size > 1:
-                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
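
To see what the new __init__ lines compute, a small worked example with illustrative numbers (not from the commit): with the default padding_size of 64, a 151,669-entry vocab pads up to 151,680, which then splits evenly across the usual power-of-two tensor-parallel sizes.

# Illustrative numbers only; pad_vocab_size is the helper added in layers/utils.py below.
org_vocab_size = 151_669
padding_size = 64
world_size = 8

org_vocab_size_padded = pad_vocab_size(org_vocab_size, padding_size)  # 151_680
# num_added_embeddings is 0 here, so the second rounding changes nothing.
num_embeddings_padded = pad_vocab_size(org_vocab_size_padded + 0, padding_size)  # 151_680

# A multiple of 64 is also a multiple of 2, 4 and 8, so every rank gets an equal slice.
assert num_embeddings_padded % world_size == 0
assert num_embeddings_padded // world_size == 18_960
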
@@ -106,6 +194,88 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
 
         self.embeddings.weight.set_value(weight_tensor)
 
+    @classmethod
+    def _get_indices(
+        cls,
+        vocab_size_paded: int,
+        org_vocab_size_padded: int,
+        vocab_size: int,
+        org_vocab_size: int,
+        tp_rank: int,
+        tp_size: int,
+    ) -> VocabParallelEmbeddingShardIndices:
+        """Get start and end indices for vocab parallel embedding, following the
+        layout outlined in the class docstring, based on the given tp_rank and
+        tp_size."""
+
+        num_added_embeddings_padded = vocab_size_paded - org_vocab_size_padded
+        padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size(
+            org_vocab_size_padded, tp_rank, tp_size
+        )
+
+        padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size(
+            num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
+        )
+        # remove padding
+        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
+        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
+        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
+        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
+        return VocabParallelEmbeddingShardIndices(
+            padded_org_vocab_start_index,
+            padded_org_vocab_end_index,
+            padded_added_vocab_start_index,
+            padded_added_vocab_end_index,
+            org_vocab_start_index,
+            org_vocab_end_index,
+            added_vocab_start_index,
+            added_vocab_end_index,
+        )
+
+    def weight_loader(self, param, loaded_weight, shard_id=None):
+        output_dim = getattr(param, "output_dim", None)
+        packed_dim = getattr(param, "packed_dim", None)
+
+        loaded_weight = get_tensor(loaded_weight)
+        if param.dtype != loaded_weight.dtype:
+            if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn:
+                loaded_weight = loaded_weight.cast(param.dtype)
+            else:
+                loaded_weight = loaded_weight.cast(param.dtype)
+
+        if output_dim is None:
+            assert (
+                param.shape == loaded_weight.shape
+            ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
+            param.set_value(loaded_weight)
+            return
+
+        start_idx = self.shard_indices.org_vocab_start_index
+        end_idx = self.shard_indices.org_vocab_end_index
+        shard_size = self.shard_indices.org_vocab_end_index - start_idx
+
+        # If param packed on the same dim we are sharding on, then
+        # need to adjust offsets of loaded weight by pack_factor.
+        if packed_dim is not None and packed_dim == output_dim:
+            packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1))
+            assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor)
+            start_idx = start_idx // packed_factor
+            shard_size = shard_size // packed_factor
+        else:
+            assert loaded_weight.shape[output_dim] == self.org_vocab_size, (
+                f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} "
+                f"!= org_vocab_size {self.org_vocab_size}"
+            )
+
+        shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)
+
+        if output_dim == 0:
+            param[: shard_weight.shape[0]].copy_(shard_weight, False)
+            param[shard_weight.shape[0] :].fill_(0)
+        else:
+            param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
+            param[:, shard_weight.shape[1] :].fill_(0)
+
     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """
         Defines the forward computation of the layer.
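
The weight_loader registered above is what lets each tensor-parallel rank load only its slice of the full checkpoint embedding and zero the padded tail. A framework-agnostic sketch of the row path (output_dim == 0), using numpy stand-ins rather than the Paddle parameter API, with the same toy shard as before:

import numpy as np

# Toy shapes: full checkpoint weight [org_vocab_size, hidden]; rank 1 of 2 with a
# 128-row padded table, i.e. 64 local rows.
org_vocab_size, hidden = 100, 16
full_weight = np.random.rand(org_vocab_size, hidden).astype("float32")

start_idx, end_idx = 64, 100               # shard_indices.org_vocab_{start,end}_index
local_param = np.empty((64, hidden), dtype="float32")

shard = full_weight[start_idx:end_idx]     # slice_fn(loaded_weight, output_dim=0, start, end)
local_param[: shard.shape[0]] = shard      # copy_(shard_weight, False)
local_param[shard.shape[0]:] = 0.0         # fill_(0) on the padding rows
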

fastdeploy/model_executor/layers/lm_head.py

Lines changed: 9 additions & 0 deletions
@@ -22,6 +22,10 @@
 from paddle.distributed import fleet
 
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.layers.utils import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    pad_vocab_size,
+)
 from fastdeploy.model_executor.utils import (
     default_weight_loader,
     set_weight_attrs,
@@ -44,6 +48,7 @@ def __init__(
         prefix: str = "",
         with_bias: bool = False,
         dtype: str = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
     ) -> None:
         """
         Parallelized LMhead.
@@ -68,6 +73,10 @@ def __init__(
         self.column_cut = True
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.fd_config = fd_config
+        self.padding_size = padding_size
+
+        if num_embeddings % self.nranks != 0:
+            num_embeddings = pad_vocab_size(num_embeddings, self.padding_size)
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
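
The LMHead change mirrors the embedding side: when the vocab does not split evenly across ranks, it is padded up front, presumably so the parallel linear layers built below see a cleanly divisible size. A quick check with illustrative numbers (not from the commit):

num_embeddings, nranks = 151_669, 4

if num_embeddings % nranks != 0:
    num_embeddings = pad_vocab_size(num_embeddings, DEFAULT_VOCAB_PADDING_SIZE)

assert num_embeddings == 151_680
assert num_embeddings % nranks == 0
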
Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""

fastdeploy/model_executor/layers/utils.py

Lines changed: 19 additions & 0 deletions
@@ -45,6 +45,14 @@
     c8_state_dict = paddle.load(cache_params, return_numpy=True)
 
 
+DEFAULT_VOCAB_PADDING_SIZE = 64
+
+
+def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
+    """Pad the vocab size to the given value."""
+    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
+
+
 def per_block_cast_to_fp8(x: Tensor, block_size: list = [128, 128]) -> Tuple[Tensor, Tensor]:
     """
     Only used in deep_gemm block wise quant weight.
@@ -372,3 +380,14 @@ def create_empty_tensor(shape: Tuple[int, ...], dtype: Union[paddle.dtype, str])
         paddle.Tensor: An empty tensor with the specified shape and data type.
     """
     return paddle.empty(list(shape), dtype=dtype)
+
+
+def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, rank: int, offset: int = 0):
+    index_f = rank * per_partition_vocab_size
+    index_l = index_f + per_partition_vocab_size
+    return index_f + offset, index_l + offset
+
+
+def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int, offset: int = 0):
+    per_partition_vocab_size = divide(global_vocab_size, world_size)
+    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, offset=offset)
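
Usage sketch for the two new range helpers, with toy numbers: a padded vocab of 128 split over 4 ranks gives each rank a contiguous 32-entry range, and offset shifts the whole window, which is how the added-vocab path in _get_indices uses it.

for rank in range(4):
    print(rank, vocab_range_from_global_vocab_size(128, rank, world_size=4))
# 0 (0, 32)
# 1 (32, 64)
# 2 (64, 96)
# 3 (96, 128)

# With an offset, the same split starts after the original vocab:
print(vocab_range_from_global_vocab_size(128, 0, world_size=4, offset=1000))  # (1000, 1032)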
