Commit e6693d5

Use torch compile cpy + tensorcore GEMM (use_custom_cublas_mm=False)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent e84d056 commit e6693d5

File tree

1 file changed: +10 −3 lines

  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 10 additions & 3 deletions
@@ -718,7 +718,7 @@ def __init__(self,
                 dtype=torch.float32,
                 quant_config=None,
                 skip_create_weights_in_init=skip_create_weights_in_init,
-                use_custom_cublas_mm=True)
+                use_custom_cublas_mm=False)
 
         self.rotary_emb = RotaryEmbedding(
             pos_embd_params.rope,
@@ -1233,10 +1233,17 @@ def sparse_attn_indexer(
             dtype=torch.int32)
         return topk_indices_buffer
 
+    @maybe_compile(dynamic=True)
+    def _scale_v2(self, hidden_states: torch.Tensor, q_scale: torch.Tensor,
+                  s: float) -> torch.Tensor:
+        weights = self.weights_proj(hidden_states.float())
+        return weights * q_scale.squeeze(-1) * s
+
     def weight_scale(self, hidden_states: torch.Tensor,
                      q_scale: torch.Tensor) -> torch.Tensor:
-        weights = self.weights_proj(hidden_states.float())
-        weights = _scale(weights, q_scale, self.weight_scale_factor)
+        #weights = self.weights_proj(hidden_states.float())
+        weights = self._scale_v2(hidden_states, q_scale,
+                                 self.weight_scale_factor)
         return weights
 
     @torch.inference_mode()
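For context, a minimal runnable sketch of the pattern this commit adopts: the weights projection and the elementwise scaling are folded into one method decorated for torch.compile, so the pointwise multiplies can fuse with the projection's GEMM, which now runs through PyTorch's stock (cuBLAS/tensor-core) matmul path since use_custom_cublas_mm=False. Only the _scale_v2/weight_scale bodies mirror the diff above; maybe_compile here is a hypothetical stand-in for the helper dsa.py imports, and the module name, layer sizes, shapes, and scale factor are illustrative assumptions.

    import torch

    def maybe_compile(**compile_kwargs):
        # Hypothetical stand-in for the `maybe_compile` helper dsa.py uses:
        # here it simply forwards to torch.compile; the real helper
        # presumably can also disable compilation.
        def decorator(fn):
            return torch.compile(fn, **compile_kwargs)
        return decorator

    class IndexerScaleSketch(torch.nn.Module):
        """Illustrative module carrying only the pieces the diff touches."""

        def __init__(self, hidden_size: int, n_heads: int, scale: float):
            super().__init__()
            # Stand-in for the indexer's weights_proj; sizes are made up.
            self.weights_proj = torch.nn.Linear(hidden_size, n_heads)
            self.weight_scale_factor = scale

        @maybe_compile(dynamic=True)
        def _scale_v2(self, hidden_states: torch.Tensor,
                      q_scale: torch.Tensor, s: float) -> torch.Tensor:
            # Projection plus pointwise scaling inside one compiled region,
            # so the multiplies can fuse around the GEMM.
            weights = self.weights_proj(hidden_states.float())
            return weights * q_scale.squeeze(-1) * s

        def weight_scale(self, hidden_states: torch.Tensor,
                         q_scale: torch.Tensor) -> torch.Tensor:
            return self._scale_v2(hidden_states, q_scale,
                                  self.weight_scale_factor)

    # Usage sketch: [num_tokens, hidden] activations and a per-token,
    # per-head scale with a trailing singleton dim (dropped by squeeze(-1)).
    m = IndexerScaleSketch(hidden_size=128, n_heads=8, scale=0.5)
    x = torch.randn(4, 128)
    q_scale = torch.rand(4, 8, 1)
    print(m.weight_scale(x, q_scale).shape)  # torch.Size([4, 8])

Decorating with dynamic=True asks torch.compile to trace with symbolic shapes, which avoids recompiling as the token count varies between batches.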
