Commit

feat: add _context_attention_kernel_with_CC in deepseek2 (#693)
blueswhen authored Jan 13, 2025
1 parent 2c64be6 commit 269631d
Showing 3 changed files with 198 additions and 246 deletions.
71 changes: 70 additions & 1 deletion lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -10,6 +10,8 @@
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
+from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv
 
 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -54,6 +56,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
             if mscale_all_dim:
                 mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
                 self.softmax_scale = self.softmax_scale * mscale * mscale
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, network_config, mode)
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
         if self.enable_dp:
@@ -65,7 +68,14 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         return
 
     def _bind_attention(self):
-        self._context_attention_kernel = partial(Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self)
+        if self.enable_cc_method:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+            )
+        else:
+            self._context_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+            )
         self._token_attention_kernel = partial(
             Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
         )
@@ -123,6 +133,65 @@ def _get_o(
         o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor
 
+    def _context_attention_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            compressed_kv = self.alloc_tensor(
+                [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
+            )
+            k_rope = self.alloc_tensor([infer_state.total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype)
+            sample_kv(
+                kv,
+                compressed_kv,
+                k_rope,
+                infer_state.b_req_idx,
+                infer_state.b_seq_len,
+                infer_state.req_manager.req_to_token_indexs,
+            )
+        else:
+            compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
+                kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
+            )
+
+        # CC
+        k_nope = self.alloc_tensor(
+            [compressed_kv.shape[0], q.shape[1], self.qk_nope_head_dim],
+            dtype=compressed_kv.dtype,
+        )
+        v = self.alloc_tensor(
+            k_nope.shape,
+            dtype=compressed_kv.dtype,
+        )
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank)
+        wv = layer_weight.v_b_proj_.weight.transpose(1, 2).view(-1, layer_weight.kv_lora_rank)
+        torch.mm(compressed_kv, wk.transpose(0, 1), out=k_nope.reshape(compressed_kv.shape[0], -1))
+        torch.mm(compressed_kv, wv.transpose(0, 1), out=v.reshape(compressed_kv.shape[0], -1))
+
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
     def _context_attention_kernel_origin(
         self,
         q: torch.Tensor,
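The core of the new path is the "# CC" block above: instead of attending directly over the compressed MLA latent cache, the kernel first materializes per-head no-RoPE keys and values by multiplying the latent vectors with the k_b_proj_ / v_b_proj_ weights, then calls the dense-V Triton prefill kernel context_attention_fwd_with_v. Below is a minimal, self-contained PyTorch sketch of that decompression step; the sizes and variable names are illustrative placeholders, not LightLLM's API.

import torch

# Illustrative MLA-style sizes: latent rank and no-RoPE head dimension.
tokens, heads = 16, 8
kv_lora_rank, qk_nope_head_dim = 512, 128

compressed_kv = torch.randn(tokens, kv_lora_rank)            # latent KV cache rows
w_k = torch.randn(heads * qk_nope_head_dim, kv_lora_rank)    # stacked per-head k_b projection
w_v = torch.randn(heads * qk_nope_head_dim, kv_lora_rank)    # stacked per-head v_b projection

# Decompress the latent cache into full per-head K (no-RoPE part) and V,
# mirroring the two torch.mm calls in _context_attention_kernel_with_CC.
k_nope = (compressed_kv @ w_k.t()).view(tokens, heads, qk_nope_head_dim)
v = (compressed_kv @ w_v.t()).view(tokens, heads, qk_nope_head_dim)

# Together with the separately cached RoPE slice of K, these feed a standard
# prefill attention kernel scaled by softmax_scale.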
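When the dynamic prompt cache is enabled, sample_kv (imported above from triton_kernel.sample_kv, one of the other changed files not shown here) fills compressed_kv and k_rope from the paged kv_buffer. As a rough plain-PyTorch approximation of what the call site suggests the kernel does, assuming it gathers each request's cached rows via req_to_token_indexs and splits off the RoPE slice:

import torch

def sample_kv_reference(kv_buffer, b_req_idx, b_seq_len, req_to_token_indexs,
                        kv_lora_rank, qk_rope_head_dim):
    # Hypothetical reference implementation; semantics inferred from the call site,
    # not from the Triton kernel itself.
    compressed_parts, rope_parts = [], []
    for req_idx, seq_len in zip(b_req_idx.tolist(), b_seq_len.tolist()):
        token_locs = req_to_token_indexs[req_idx, :seq_len].long()  # buffer rows for this request
        rows = kv_buffer[token_locs]                                # (seq_len, 1, kv_lora_rank + qk_rope_head_dim)
        compressed_parts.append(rows[..., :kv_lora_rank])
        rope_parts.append(rows[..., kv_lora_rank:kv_lora_rank + qk_rope_head_dim])
    # Concatenate in batch order so the result lines up with b_start_loc / b_seq_len.
    return torch.cat(compressed_parts, dim=0), torch.cat(rope_parts, dim=0)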

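The new kernel is opt-in: __init__ reads ENABLE_CC_METHOD from the environment and _bind_attention chooses between _context_attention_kernel_with_CC and _context_attention_kernel_origin accordingly. Since the flag is read at construction time, it must be set before the model is built; a hypothetical snippet (the variable name comes from the diff, the surrounding launch code is only an assumption):

import os

# Any of "ON", "TRUE", "1" enables the CC prefill path, per the check in __init__.
os.environ["ENABLE_CC_METHOD"] = "1"

# ... construct / launch the DeepSeek-V2 model as usual; _bind_attention will now
# bind _context_attention_kernel_with_CC instead of _context_attention_kernel_origin.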
