 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.utils import (
     default_weight_loader,
+    process_weight_transpose,
     set_weight_attrs,
     slice_fn,
 )
@@ -43,24 +44,36 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
         - weight_loader: a callable or method responsible for loading the weight data
         """
+        self.model_format = extra_weight_attrs.get("model_format")
+        self.weight_shape = (
+            layer.weight_shape[::-1] if extra_weight_attrs.get("model_format") == "torch" else layer.weight_shape
+        )
+
         layer.weight = layer.create_parameter(
-            shape=layer.weight_shape,
+            shape=self.weight_shape,
             dtype=layer.weight_dtype,
             is_bias=False,
             default_initializer=paddle.nn.initializer.Constant(0),
         )
         split_axis = extra_weight_attrs.get("split_axis")
         if hasattr(layer, "nranks") and layer.nranks > 0:
             _set_var_distributed(layer.weight, split_axis=split_axis)
+
+        if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
+            extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
+
         set_weight_attrs(
             layer.weight,
             {
                 **extra_weight_attrs,
                 "weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config)),
-                "weight_need_transpose": extra_weight_attrs.get("model_format") == "torch",
             },
         )
 
+    def process_weights_after_loading(self, layer):
+        if self.model_format == "torch":
+            process_weight_transpose(layer, "weight")
+
     def process_loaded_weights(self, layer, weights) -> None:
         # mlp.gate.weight is precision-sensitive, so we cast it to float32 for computation
         if layer.weight.dtype != weights.dtype:
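Note on the torch-format handling above: torch checkpoints store Linear weights as [out_features, in_features], while the Paddle-side layout used here is [in_features, out_features]. The parameter is therefore declared in the reversed (checkpoint) shape, the output_dim attribute is flipped so sharding slices the correct axis of that layout, and the actual transpose is deferred to process_weights_after_loading instead of tagging the weight with weight_need_transpose. A minimal standalone sketch of that flow, using NumPy stand-ins rather than the real FastDeploy layer objects:

# Standalone sketch (NumPy stand-ins, not the real FastDeploy layer objects).
import numpy as np

def declared_shape(paddle_shape, model_format):
    # torch checkpoints store Linear weights as [out_features, in_features], so for
    # model_format == "torch" the parameter is created in the reversed shape and the
    # checkpoint tensor can be copied in without an immediate transpose.
    return paddle_shape[::-1] if model_format == "torch" else paddle_shape

def transpose_after_loading(weight, model_format):
    # Mirrors the role of process_weights_after_loading: once loading is finished,
    # flip the torch-layout weight back into the [in_features, out_features] layout.
    return weight.T if model_format == "torch" else weight

paddle_shape = [4096, 11008]                              # [in_features, out_features]
ckpt = np.zeros(declared_shape(paddle_shape, "torch"))    # shape (11008, 4096)
assert transpose_after_loading(ckpt, "torch").shape == tuple(paddle_shape)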
@@ -165,7 +178,7 @@ def __init__(
         if self.with_bias:
             self.bias = self.create_parameter(
                 shape=[self.output_size],
-                dtype=self._dtype,
+                dtype=self.weight_dtype,
                 is_bias=True,
             )
             setattr(
@@ -262,6 +275,7 @@ def __init__(
         skip_quant: bool = False,
         weight_dtype: str = "",
         weight_key: str = "",
+        model_format: Optional[str] = None,
     ):
         """
         Initializes a replicated linear layer.
@@ -296,7 +310,7 @@ def __init__(
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
-            model_format=fd_config.model_config.model_format,
+            model_format=fd_config.model_config.model_format if model_format is None else model_format,
         )
 
 
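The new model_format argument gives ReplicatedLinear a per-layer override of the checkpoint format; leaving it as None preserves the previous behaviour of inheriting fd_config.model_config.model_format. A tiny sketch of the fallback rule (the helper name is illustrative, not part of FastDeploy):

from typing import Optional

def resolve_model_format(config_format: str, model_format: Optional[str] = None) -> str:
    # A per-layer override wins; None falls back to the globally configured format.
    return config_format if model_format is None else model_format

assert resolve_model_format("torch") == "torch"
assert resolve_model_format("torch", model_format="paddle") == "paddle"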
@@ -344,7 +358,6 @@ def __init__(
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
-        loaded_weight = get_tensor(loaded_weight)
 
         if weight_need_transpose:
             loaded_weight = loaded_weight.transpose([1, 0])
@@ -393,7 +406,7 @@ def __init__(
         with_bias: bool = False,
         add_bias: bool = False,
         skip_quant: bool = False,
-        weight_dtype="",
+        weight_dtype: str = "",
     ):
         """
         Initializes a linear layer and provides additional parameters required for inference and quantization.
@@ -500,7 +513,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
         output_size = param.shape[shard_dim]
         if loaded_shard_id is None:
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
                 # Avoid redundant transpose of fused weights when weight_loader is called iteratively
                 param.weight_need_transpose = False
@@ -519,7 +531,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
             # split gate up
             assert loaded_shard_id in ["gate", "up"]
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
@@ -532,7 +543,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
                 shard_offset = self.local_rank * block_size
                 shard_size = (self.local_rank + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
-            loaded_weight = get_tensor(loaded_weight)
             if not param._is_initialized():
                 param.initialize()
             param_shard_size = output_size // 2
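For the gate/up branch above, each rank slices its own block out of the full gate or up projection before copying it into its half of the fused local parameter. A worked example of the offset arithmetic; block_size is computed just above the shown hunk, so treating it as the full projection width divided by nranks is an assumption:

# Pure-Python worked example; sizes are illustrative.
nranks, local_rank = 4, 1
full_gate_out = 11008                       # output width of the unsharded gate projection
block_size = full_gate_out // nranks        # assumed formula; computed above the shown hunk
shard_offset = local_rank * block_size      # 2752, slice start
shard_size = (local_rank + 1) * block_size  # 5504, used directly as the slice end
# slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size) would take
# columns [2752, 5504) of the full gate weight for this rank; the fused local parameter
# then holds gate in its first half and up in its second (param_shard_size = output_size // 2).
assert shard_size - shard_offset == block_size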
@@ -589,7 +599,19 @@ class QKVParallelLinear(ColumnParallelLinear):
     QKVParallelLinear Layer.
     """
 
-    def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
+    def __init__(
+        self,
+        fd_config,
+        prefix,
+        with_bias=False,
+        add_bias=True,
+        num_heads: Optional[int] = None,
+        kv_num_heads: Optional[int] = None,
+        hidden_size: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        skip_quant: bool = False,
+        weight_dtype: str = "",
+    ):
         """
         Initialize the QKV Linear layer with given parameters.
 
@@ -599,11 +621,15 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
                 Can be arbitrarily named.
             with_bias (bool): Whether to include bias or not. Defaults to False.
             add_bias (bool): Whether to add bias in the current layer or in the pre/post layer. Defaults to True.
+            num_heads (Optional[int]): Number of attention heads in the model.
+            kv_num_heads (Optional[int]): Number of key/value heads, used for multi-query or grouped-query attention.
+            hidden_size (Optional[int]): Total hidden layer dimension, typically the embedding size.
+            head_dim (Optional[int]): Size of each attention head, usually computed as hidden_size divided by num_heads.
         """
-        self.num_heads = fd_config.model_config.num_attention_heads
-        self.kv_num_heads = fd_config.model_config.num_key_value_heads
-        self.hidden_size = fd_config.model_config.hidden_size
-        self.head_dim = fd_config.model_config.head_dim
+        self.num_heads = fd_config.model_config.num_attention_heads if num_heads is None else num_heads
+        self.kv_num_heads = fd_config.model_config.num_key_value_heads if kv_num_heads is None else kv_num_heads
+        self.hidden_size = fd_config.model_config.hidden_size if hidden_size is None else hidden_size
+        self.head_dim = fd_config.model_config.head_dim if head_dim is None else head_dim
         self.nranks = fd_config.parallel_config.tensor_parallel_size
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
         self.num_heads_per_rank = divide(self.num_heads, self.nranks)
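The new optional arguments let a caller build a QKVParallelLinear whose head geometry differs from the top-level model config (leaving them as None keeps the values from fd_config.model_config), while the per-rank division stays the same. A small worked example of the resulting sizes; the fused q/k/v output-width formula is the conventional layout and an assumption about code outside this hunk:

# Worked example of the per-rank head split; divide() is a stand-in for the helper
# used above, which requires exact divisibility.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator

num_heads, kv_num_heads, head_dim, nranks = 32, 8, 128, 4
num_heads_per_rank = divide(num_heads, nranks)         # 8 query heads per rank
kv_num_heads_per_rank = divide(kv_num_heads, nranks)   # 2 key/value heads per rank
# Assumed fused-QKV width per rank: q, k and v concatenated along the output dim.
per_rank_output_size = (num_heads_per_rank + 2 * kv_num_heads_per_rank) * head_dim
assert per_rank_output_size == (8 + 2 * 2) * 128  # 1536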
@@ -623,6 +649,8 @@ def __init__(self, fd_config, prefix, with_bias=False, add_bias=True):
             output_size=output_size,
             with_bias=with_bias,
             add_bias=add_bias,
+            skip_quant=skip_quant,
+            weight_dtype=weight_dtype,
         )
 
     def _get_shard_size_mapping(self, loaded_shard_id: str, head_dim: int):
@@ -641,7 +669,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
         weight_need_transpose = getattr(param, "weight_need_transpose", False)
         if loaded_shard_id is None:
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
                 # Avoid redundant transpose of fused weights when weight_loader is called iteratively
                 param.weight_need_transpose = False
@@ -661,7 +688,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
             # split q k v
             assert loaded_shard_id in ["q", "k", "v"]
             if weight_need_transpose:
-                loaded_weight = get_tensor(loaded_weight)
                 loaded_weight = loaded_weight.transpose([1, 0])
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
@@ -671,8 +697,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
                 shard_size = block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_offset + shard_size)
 
-            loaded_weight = get_tensor(loaded_weight)
-
             if not param._is_initialized():
                 param.initialize()
 
@@ -798,7 +822,7 @@ def __init__(
         add_bias: bool = False,
         reduce_results: bool = True,
         skip_quant: bool = False,
-        weight_dtype="",
+        weight_dtype: str = "",
     ):
         """
         Initialize a linear layer with additional parameters for inference and quantization.
@@ -847,10 +871,6 @@ def __init__(
             ),
             model_format=fd_config.model_config.model_format,
         )
-        if self.nranks > 0:
-            if self.with_bias:
-                # col parallel
-                _set_var_distributed(self.bias, split_axis=0)
 
         self.reduce_results = reduce_results
 