@@ -440,17 +440,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
 
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return
 
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
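The new GGUF branch defers materialization: each incoming shard is first narrowed to this rank's slice of the output dimension, then parked in `param.data_container`, with `shard_id_map` remembering which slot belongs to which shard; only once both merged shards (e.g. gate and up projections) have arrived does `materialize_nested()` assemble the final quantized weight. `materialize_nested` is defined on the GGUF parameter class elsewhere in this change; the sketch below is only a minimal stand-in for the accumulate-then-materialize pattern, under the assumption that shards are concatenated along dim 0:

```python
import torch

class ShardAccumulator:
    """Hypothetical stand-in for the deferred-materialization pattern:
    collect shards in arrival order, stitch them together once all of
    the expected shards are present."""

    def __init__(self, shard_order):
        self.shard_order = list(shard_order)  # logical order, e.g. [0, 1]
        self.shard_id_map = {}                # shard id -> container slot
        self.data_container = []              # shard tensors, arrival order

    def add(self, shard_id, weight: torch.Tensor):
        self.shard_id_map[shard_id] = len(self.data_container)
        self.data_container.append(weight)
        if len(self.data_container) < len(self.shard_order):
            return None  # keep waiting, mirroring the early `return`
        # All shards present: concatenate in logical order, which need
        # not match arrival order.
        return torch.cat([
            self.data_container[self.shard_id_map[s]]
            for s in self.shard_order
        ], dim=0)

acc = ShardAccumulator(shard_order=[0, 1])       # e.g. gate, up
assert acc.add(1, torch.ones(4, 8)) is None      # shard 1 arrives first
full = acc.add(0, torch.zeros(4, 8))             # shard 0 completes it
assert full.shape == (8, 8)
assert torch.equal(full[:4], torch.zeros(4, 8))  # shard 0 lands first
```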
@@ -515,18 +521,6 @@ def weight_loader(self,
             shard_offset = loaded_weight.shape[output_dim] * \
                 loaded_shard_id
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
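Both the added branch above and the deleted one here revolve around the same tensor-parallel slicing idiom: rank `r` of `tp_size` ranks takes the contiguous slice `[r * shard_size, (r + 1) * shard_size)` of the output dimension. A standalone illustration in plain PyTorch (function name and sizes are illustrative only):

```python
import torch

def tp_shard(weight: torch.Tensor, output_dim: int,
             tp_rank: int, tp_size: int) -> torch.Tensor:
    """Return this rank's contiguous slice along output_dim. Assumes
    the dimension divides evenly by tp_size, as the diff's code does."""
    shard_size = weight.size(output_dim) // tp_size
    start_idx = tp_rank * shard_size
    # narrow() returns a zero-copy view into the original tensor.
    return weight.narrow(output_dim, start_idx, shard_size)

w = torch.arange(12.0).reshape(4, 3)              # toy 4x3 "weight"
shards = [tp_shard(w, 0, r, 2) for r in range(2)]
assert torch.equal(torch.cat(shards, dim=0), w)   # shards tile the weight
```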
@@ -783,17 +777,23 @@ def weight_loader(self,
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
 
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return
 
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
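This is the same accumulate-then-materialize flow as the merged-column hunk above, with the completion threshold raised from 2 to 3: a fused QKV parameter is only whole after the `"q"`, `"k"`, and `"v"` shards have all been seen. Reusing the hypothetical `ShardAccumulator` from the first sketch, string shard ids behave identically:

```python
# Only the expected shard count and the id type change for fused QKV.
acc = ShardAccumulator(shard_order=["q", "k", "v"])
assert acc.add("q", torch.zeros(4, 8)) is None
assert acc.add("v", torch.full((4, 8), 2.0)) is None  # out of order is fine
full = acc.add("k", torch.ones(4, 8))
assert full.shape == (12, 8)
assert torch.equal(full[4:8], torch.ones(4, 8))   # "k" lands in the middle
```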
@@ -883,18 +883,6 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":