@@ -559,7 +559,7 @@ def create_weights(self, module: Linear, in_features: int,
                                          dtype=torch.float8_e4m3fn),
                                  requires_grad=False)
 
-        if get_sm_version() == 100:
+        if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
             scale_shape = (math.ceil(in_features / 512),
                            math.ceil(out_features))
             module.weight_scale = Parameter(torch.empty(scale_shape,
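The hunk above only changes the gating condition: the SM100-specific weight-scale shape is created unless the CuTe DSL block-scaling MM is enabled. A minimal sketch of that branch follows; the flag, the SM100 shape, and the `requires_grad=False` handling mirror the diff, while the `else` branch and the float32 scale dtype are illustrative assumptions (the original call is truncated in this hunk).

```python
import math

import torch
from torch.nn import Parameter


def make_weight_scale(in_features: int, out_features: int, sm_version: int,
                      use_cute_dsl_blockscaling_mm: bool) -> Parameter:
    if sm_version == 100 and not use_cute_dsl_blockscaling_mm:
        # Layout consumed by the SM100 block-scaling GEMM, as in the hunk above.
        scale_shape = (math.ceil(in_features / 512), math.ceil(out_features))
    else:
        # Assumed plain 128x128-block scale layout for the other paths.
        scale_shape = (math.ceil(out_features / 128),
                       math.ceil(in_features / 128))
    # float32 is an assumed dtype; the diff cuts off before the original one.
    return Parameter(torch.empty(scale_shape, dtype=torch.float32),
                     requires_grad=False)
```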
@@ -595,6 +595,7 @@ def apply(self, module: Linear, input: torch.Tensor,
             # TODO (@lmin): replace with cute_dsl gemm
             act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                 input)
+            print(module.weight_scale.dtype)
             output = torch.ops.trtllm.fp8_block_scaling_gemm(
                 act_input_fp8, module.weight, act_input_sf,
                 module.weight_scale)
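For context, the apply path touched above quantizes the activations to FP8 with one scale per 1x128 block and feeds them to the block-scaling GEMM. A minimal sketch, assuming a TensorRT-LLM build in which the `trtllm` custom ops are registered; the op names and argument order come from the hunk, everything else is illustrative.

```python
import torch


def fp8_block_scaling_linear(module, input: torch.Tensor) -> torch.Tensor:
    # Quantize activations to FP8 with per-(1x128)-block scale factors.
    act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(input)
    # Blockwise-scaled GEMM against the pre-quantized FP8 weight.
    return torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, module.weight,
                                                   act_input_sf,
                                                   module.weight_scale)
```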
@@ -649,7 +650,7 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
                                          module.tp_rank,
                                          module.tp_mode).squeeze()
-        if get_sm_version() == 100:
+        if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
             weight_scale = fp8_utils.transform_sf_into_required_layout(
                 weight_scale,
                 mn=module.weight.shape[0],
@@ -692,7 +693,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
                                          module.tp_rank, module.tp_mode)
         fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
         copy_weight(module.weight, fused_weight)
-        if get_sm_version() == 100:
+        if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
             fused_scale = fp8_utils.transform_sf_into_required_layout(
                 fused_scale,
                 mn=fused_weight.shape[0],
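The last two hunks apply the same gating to the weight-loading paths; in the fused gate_up case the per-projection scale shards are first stacked along the output dimension before the optional SM100 layout transform. A self-contained toy illustration of that concatenation, with hypothetical block counts:

```python
import torch

# Hypothetical numbers of scale blocks along the output and input dimensions.
out_blocks, in_blocks = 4, 3
left_scale = torch.rand(out_blocks, in_blocks)   # gate projection scales
right_scale = torch.rand(out_blocks, in_blocks)  # up projection scales

# Mirrors the hunk above: stack gate and up scales along dim 0, then squeeze.
fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
assert fused_scale.shape == (2 * out_blocks, in_blocks)
```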