@@ -950,10 +950,14 @@ def symmetric_quantize_last_axis_of_batched_matrix(weight, quant_mode):
     return qweight, scale
 
 
-def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
-                                      quant_mode: torch.dtype,
-                                      act_dtype: torch.dtype,
-                                      sm_: int = -1) -> torch.Tensor:
+def preprocess_weights_for_mixed_gemm(
+        tensor: torch.Tensor,
+        quant_mode: torch.dtype,
+        act_dtype: torch.dtype,
+        sm_: int = -1,
+        do_weight_interleave: bool = True) -> torch.Tensor:
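+    # do_weight_interleave=False skips the row permutation and the pre-Hopper
+    # interleaving steps below; the sub-byte transpose always runs.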
     sm_ = sm_ if sm_ > 0 else get_sm_version()
     if len(tensor.shape) == 2:
         tensor = tensor.unsqueeze(0)
@@ -988,13 +992,13 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
     assert (num_rows % B_ROWS_PER_MMA == 0)
     assert (num_cols % MMA_SHAPE_N == 0)
 
-    row_idx_list = [
-        (row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
-        permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][row_idx %
-                                                              B_ROWS_PER_MMA]
-        for row_idx in range(num_rows)
-    ]
-    tensor = tensor[:, row_idx_list, :]
+    if do_weight_interleave:
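+        # permute rows within each B_ROWS_PER_MMA block using the per-dtype permutation map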
+        row_idx_list = [(row_idx // B_ROWS_PER_MMA) * B_ROWS_PER_MMA +
+                        permutation_map[f"{BITS_PER_ELT_A}_{BITS_PER_ELT_B}"][
+                            row_idx % B_ROWS_PER_MMA]
+                        for row_idx in range(num_rows)]
+        tensor = tensor[:, row_idx_list, :]
 
     # subbyte_transpose
     original_shape = tensor.shape
@@ -1010,42 +1014,68 @@ def preprocess_weights_for_mixed_gemm(tensor: torch.Tensor,
     else:
         tensor = tensor.permute(0, 2, 1).reshape(original_shape)
 
-    # interleave_column_major_tensor
-    interleave = BITS_PER_ELT_A // BITS_PER_ELT_B
-    if interleave > 1 and sm_ < 90:
-        rows_per_tile = 128 * 8 // BITS_PER_ELT_A
-        elts_in_int32 = 32 // BITS_PER_ELT_B
-
-        assert (num_rows % elts_in_int32 == 0)
-        assert (num_rows % rows_per_tile == 0)
-
-        tensor = tensor.reshape(num_experts, -1, interleave,
-                                num_rows // rows_per_tile,
-                                rows_per_tile * 4 // elts_in_int32)
-        tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape)
-
-    # add_bias_and_interleave_quantized_tensor_inplace
-    if BITS_PER_ELT_B == 8:
-        tensor += -256 * (tensor > 127).byte() + 128
-        tensor = tensor.reshape(-1, 4)[:, [0, 2, 1, 3]].reshape(tensor.shape)
-    elif BITS_PER_ELT_B == 4:
-        tensor = tensor.view(torch.uint8)
-        high_tensor = (tensor >> 4).unsqueeze(-1)
-        low_tensor = ((tensor << 4) >> 4).unsqueeze(-1)
-        new_tensor = torch.cat([low_tensor, high_tensor],
-                               dim=-1).reshape(tensor.shape[0], tensor.shape[1],
-                                               -1)
-        new_tensor = new_tensor.reshape(
-            -1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape)
-        new_tensor += -16 * (new_tensor > 7).byte() + 8
-        new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
-        tensor = new_tensor.view(torch.int8)
-    else:
-        raise NotImplementedError
+    if do_weight_interleave:
+        # interleave_column_major_tensor
+        interleave = BITS_PER_ELT_A // BITS_PER_ELT_B
+        if interleave > 1 and sm_ < 90:
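+            # only applies on pre-Hopper parts when activation bits exceed weight bits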
+            rows_per_tile = 128 * 8 // BITS_PER_ELT_A
+            elts_in_int32 = 32 // BITS_PER_ELT_B
+
+            assert (num_rows % elts_in_int32 == 0)
+            assert (num_rows % rows_per_tile == 0)
+
+            tensor = tensor.reshape(num_experts, -1, interleave,
+                                    num_rows // rows_per_tile,
+                                    rows_per_tile * 4 // elts_in_int32)
+            tensor = tensor.permute(0, 1, 3, 2, 4).reshape(original_shape)
+
+        # add_bias_and_interleave_quantized_tensor_inplace
+        if BITS_PER_ELT_B == 8:
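+            # shift int8 weights by +128 into unsigned range, then swap the middle pair of every 4-element group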
+            tensor += -256 * (tensor > 127).byte() + 128
+            tensor = tensor.reshape(-1, 4)[:,
+                                           [0, 2, 1, 3]].reshape(tensor.shape)
+        elif BITS_PER_ELT_B == 4:
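+            # unpack each byte into two nibbles, interleave them as [0, 2, 4, 6, 1, 3, 5, 7], add the +8 bias and repack two per byte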
+            tensor = tensor.view(torch.uint8)
+            high_tensor = (tensor >> 4).unsqueeze(-1)
+            low_tensor = ((tensor << 4) >> 4).unsqueeze(-1)
+            new_tensor = torch.cat([low_tensor, high_tensor],
+                                   dim=-1).reshape(tensor.shape[0],
+                                                   tensor.shape[1], -1)
+            new_tensor = new_tensor.reshape(
+                -1, 8)[:, [0, 2, 4, 6, 1, 3, 5, 7]].reshape(new_tensor.shape)
+            new_tensor += -16 * (new_tensor > 7).byte() + 8
+            new_tensor = new_tensor[:, :, 0::2] + new_tensor[:, :, 1::2] * 16
+            tensor = new_tensor.view(torch.int8)
+        else:
+            raise NotImplementedError
 
     return tensor.squeeze(0).contiguous()
 
 
+def get_weight_scale_interleave_factor(interleaved_dim: int,
+                                       group_size: int = 128) -> int:
+    # Calculate the weight_scale interleave factor for W4A8 groupwise MoE quant.
+    # Only Hopper W4A8 interleaves the weight scale; other architectures and Hopper W4A16 default to 1.
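+    # e.g. with a hypothetical interleaved_dim of 7168 and group_size of 128,
+    # 7168 % (4 * 128) == 0, so the factor is 4 on SM90.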
+    factor = 1
+    if get_sm_version() == 90:
+        if interleaved_dim % (4 * group_size) == 0:
+            factor = 4
+        elif interleaved_dim % (2 * group_size) == 0:
+            factor = 2
+        elif interleaved_dim % group_size == 0:
+            factor = 1
+        else:
+            raise NotImplementedError(
+                f"Interleaved dimension must be a multiple of group_size ({group_size}), received {interleaved_dim}."
+            )
+    return factor
+
+
 def validate_group_size(layer):
     # TODO: Remove this function and its usage after W4A8-AWQ with group_size = 64 is implemented.
     W4A8_AWQ = 8