 # limitations under the License.
 """

+from dataclasses import dataclass
 from typing import Dict

 import numpy as np
 from paddle.distributed import fleet

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.utils import set_weight_attrs
+from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn

-from .utils import get_tensor
+from .utils import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    get_tensor,
+    pad_vocab_size,
+    vocab_range_from_global_vocab_size,
+)
+
+
+@dataclass
+class VocabParallelEmbeddingShardIndices:
+    """Indices for a shard of a vocab parallel embedding."""
+
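+    # The padded_* indices describe this rank's slice of the padded vocab
+    # (the original vocab and any added vocab are padded and partitioned
+    # separately); the unpadded indices below are the same ranges clamped to
+    # the real vocab sizes, so the *_padding properties give per-shard padding.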
+    padded_org_vocab_start_index: int
+    padded_org_vocab_end_index: int
+    padded_added_vocab_start_index: int
+    padded_added_vocab_end_index: int
+
+    org_vocab_start_index: int
+    org_vocab_end_index: int
+    added_vocab_start_index: int
+    added_vocab_end_index: int
+
+    @property
+    def num_org_elements(self) -> int:
+        return self.org_vocab_end_index - self.org_vocab_start_index
+
+    @property
+    def num_added_elements(self) -> int:
+        return self.added_vocab_end_index - self.added_vocab_start_index
+
+    @property
+    def num_org_elements_padded(self) -> int:
+        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
+
+    @property
+    def num_added_elements_padded(self) -> int:
+        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
+
+    @property
+    def num_org_vocab_padding(self) -> int:
+        return self.num_org_elements_padded - self.num_org_elements
+
+    @property
+    def num_added_vocab_padding(self) -> int:
+        return self.num_added_elements_padded - self.num_added_elements
+
+    @property
+    def num_elements_padded(self) -> int:
+        return self.num_org_elements_padded + self.num_added_elements_padded
+
+    def __post_init__(self):
+        # sanity checks
+        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
+        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
+
+        assert self.org_vocab_start_index <= self.org_vocab_end_index
+        assert self.added_vocab_start_index <= self.added_vocab_end_index
+
+        assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
+        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
+        assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
+        assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
+
+        assert self.num_org_elements <= self.num_org_elements_padded
+        assert self.num_added_elements <= self.num_added_elements_padded


 class VocabParallelEmbedding(nn.Layer):
@@ -39,6 +104,7 @@ def __init__(
         embedding_dim: int = 768,
         params_dtype: str = "bfloat16",
         prefix="",
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
     ) -> None:
         """
         Initialize the VocabParallelEmbedding layer for the model.
@@ -65,18 +131,40 @@ def __init__(
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype
+        self.padding_size = padding_size
+
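+        # Pad the vocab size to a multiple of padding_size before sharding it
+        # across tensor-parallel ranks; shard_indices records this rank's slice
+        # of both the padded and the unpadded vocab.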
+        self.org_vocab_size = num_embeddings
+        self.num_embeddings = num_embeddings
+        num_added_embeddings = num_embeddings - self.org_vocab_size
+
+        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, self.padding_size)
+        self.num_embeddings_padded = pad_vocab_size(
+            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+        )
+        assert self.org_vocab_size_padded <= self.num_embeddings_padded
+        self.shard_indices = self._get_indices(
+            self.num_embeddings_padded,
+            self.org_vocab_size_padded,
+            self.num_embeddings,
+            self.org_vocab_size,
+            self.tensor_parallel_rank,
+            self.world_size,
+        )
+
+        if num_embeddings % self.world_size != 0:
+            self.num_embeddings_padded = pad_vocab_size(num_embeddings, self.padding_size)

         if not self.column_cut:
             self.embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-                num_embeddings,
+                self.num_embeddings_padded,
                 embedding_dim,
                 mp_group=self.tp_group,
                 weight_attr=paddle.ParamAttr(
                     initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                 ),
             )
             if self.world_size > 1:
-                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False, "weight_loader": self.weight_loader})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
@@ -106,6 +194,88 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):

         self.embeddings.weight.set_value(weight_tensor)

+    @classmethod
+    def _get_indices(
+        cls,
+        vocab_size_padded: int,
+        org_vocab_size_padded: int,
+        vocab_size: int,
+        org_vocab_size: int,
+        tp_rank: int,
+        tp_size: int,
+    ) -> VocabParallelEmbeddingShardIndices:
+        """Get start and end indices for vocab parallel embedding, following the
+        layout outlined in the class docstring, based on the given tp_rank and
+        tp_size."""
+
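+        # Each rank gets an equal slice of the padded original vocab and of the
+        # padded added vocab; the un-padded indices are then obtained by
+        # clamping those slices to the real vocab sizes.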
+        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
+        padded_org_vocab_start_index, padded_org_vocab_end_index = vocab_range_from_global_vocab_size(
+            org_vocab_size_padded, tp_rank, tp_size
+        )
+
+        padded_added_vocab_start_index, padded_added_vocab_end_index = vocab_range_from_global_vocab_size(
+            num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
+        )
+        # remove padding
+        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
+        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
+        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
+        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
+        return VocabParallelEmbeddingShardIndices(
+            padded_org_vocab_start_index,
+            padded_org_vocab_end_index,
+            padded_added_vocab_start_index,
+            padded_added_vocab_end_index,
+            org_vocab_start_index,
+            org_vocab_end_index,
+            added_vocab_start_index,
+            added_vocab_end_index,
+        )
+
+    def weight_loader(self, param, loaded_weight, shard_id=None):
+        output_dim = getattr(param, "output_dim", None)
+        packed_dim = getattr(param, "packed_dim", None)
+
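+        # `output_dim` and `packed_dim` are attached to the parameter via
+        # set_weight_attrs. If `output_dim` is unset the weight is replicated
+        # and copied as-is; otherwise the checkpoint tensor (which holds the
+        # full, unsharded vocab) is sliced to this rank's vocab shard below.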
+        loaded_weight = get_tensor(loaded_weight)
+        if param.dtype != loaded_weight.dtype:
+            loaded_weight = loaded_weight.cast(param.dtype)
+
+        if output_dim is None:
+            assert (
+                param.shape == loaded_weight.shape
+            ), f"Shape mismatch: param {param.shape} vs loaded_weight {loaded_weight.shape}"
+            param.set_value(loaded_weight)
+            return
+
+        start_idx = self.shard_indices.org_vocab_start_index
+        end_idx = self.shard_indices.org_vocab_end_index
+        shard_size = end_idx - start_idx
+
+        # If the param is packed along the dim we shard on, adjust the offsets
+        # of the loaded weight by the pack factor.
+        if packed_dim is not None and packed_dim == output_dim:
+            packed_factor = getattr(param, "packed_factor", getattr(param, "pack_factor", 1))
+            assert loaded_weight.shape[output_dim] == (self.org_vocab_size // packed_factor)
+            start_idx = start_idx // packed_factor
+            shard_size = shard_size // packed_factor
+        else:
+            assert loaded_weight.shape[output_dim] == self.org_vocab_size, (
+                f"Loaded weight dim {output_dim} size {loaded_weight.shape[output_dim]} "
+                f"!= org_vocab_size {self.org_vocab_size}"
+            )
+
+        shard_weight = slice_fn(loaded_weight, output_dim, start_idx, start_idx + shard_size)
+
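+        # Copy this rank's shard into the local parameter; any trailing rows or
+        # columns past the real shard correspond to vocab padding and are
+        # zero-filled.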
+        if output_dim == 0:
+            param[: shard_weight.shape[0]].copy_(shard_weight, False)
+            param[shard_weight.shape[0] :].fill_(0)
+        else:
+            param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
+            param[:, shard_weight.shape[1] :].fill_(0)
+
     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """
         Defines the forward computation of the layer.