1 file changed: +10 −3
tensorrt_llm/_torch/attention_backend/sparse

@@ -718,7 +718,7 @@ def __init__(self,
                 dtype=torch.float32,
                 quant_config=None,
                 skip_create_weights_in_init=skip_create_weights_in_init,
-                use_custom_cublas_mm=True)
+                use_custom_cublas_mm=False)
 
         self.rotary_emb = RotaryEmbedding(
             pos_embd_params.rope,
@@ -1233,10 +1233,17 @@ def sparse_attn_indexer(
                 dtype=torch.int32)
         return topk_indices_buffer
 
+    @maybe_compile(dynamic=True)
+    def _scale_v2(self, hidden_states: torch.Tensor, q_scale: torch.Tensor,
+                  s: float) -> torch.Tensor:
+        weights = self.weights_proj(hidden_states.float())
+        return weights * q_scale.squeeze(-1) * s
+
     def weight_scale(self, hidden_states: torch.Tensor,
                      q_scale: torch.Tensor) -> torch.Tensor:
-        weights = self.weights_proj(hidden_states.float())
-        weights = _scale(weights, q_scale, self.weight_scale_factor)
+        # weights = self.weights_proj(hidden_states.float())
+        weights = self._scale_v2(hidden_states, q_scale,
+                                 self.weight_scale_factor)
         return weights
 
     @torch.inference_mode()
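The change above moves the `weights_proj` projection and the `q_scale` / scale-factor multiply into a single small method gated by `maybe_compile(dynamic=True)`, so the float32 matmul and the elementwise scaling can be fused into one compiled region. The sketch below reproduces that pattern as a standalone module; `IndexerScaleSketch`, the tensor shapes, and the plain `@torch.compile(dynamic=True)` decorator (standing in for TensorRT-LLM's `maybe_compile`) are illustrative assumptions, not the library's actual classes.

```python
# Minimal sketch of the pattern in the diff: keep the projection and the
# scaling inside one small compiled method so they can be fused.
import torch
from torch import nn


class IndexerScaleSketch(nn.Module):
    """Hypothetical stand-in for the indexer module touched by the diff."""

    def __init__(self, hidden_size: int, n_heads: int, scale_factor: float):
        super().__init__()
        # weights_proj maps hidden states to one logit per attention head.
        self.weights_proj = nn.Linear(hidden_size, n_heads, dtype=torch.float32)
        self.weight_scale_factor = scale_factor

    @torch.compile(dynamic=True)  # assumed analogue of @maybe_compile(dynamic=True)
    def _scale_v2(self, hidden_states: torch.Tensor, q_scale: torch.Tensor,
                  s: float) -> torch.Tensor:
        # Projection and elementwise scaling live in one compiled region.
        weights = self.weights_proj(hidden_states.float())
        return weights * q_scale.squeeze(-1) * s

    def weight_scale(self, hidden_states: torch.Tensor,
                     q_scale: torch.Tensor) -> torch.Tensor:
        return self._scale_v2(hidden_states, q_scale, self.weight_scale_factor)


# Usage with made-up shapes: [num_tokens, hidden] states and a per-token,
# per-head quantization scale carrying a trailing singleton dimension.
mod = IndexerScaleSketch(hidden_size=64, n_heads=8, scale_factor=0.5)
hidden = torch.randn(16, 64)
q_scale = torch.rand(16, 8, 1)
out = mod.weight_scale(hidden, q_scale)
print(out.shape)  # torch.Size([16, 8])
```

Keeping the matmul and the scaling in one decorated method, rather than scaling the projection output in the caller, is what lets the compiler see both operations together; the caller-side code stays a thin wrapper.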