Commit b869ad3

[Ascend] Wx/fix_varlen_flash_attention_on_ascend (#1158)
* fix varlen flash attention bug
1 parent 132a28c commit b869ad3

File tree: 5 files changed (+64, -10 lines)


diopi_test/python/configs/diopi_configs.py (+2, -1)

@@ -8985,7 +8985,8 @@
             p_dropout=[0, 0, 0, 0],
             is_causal=[True, True, False, True],
             softmax_scale=[None, 0.0883, None, 0.125],
-            max_seqlen=[32, 32, 128, 64],
+            max_seqlen_q=[32, 32, 128, 64],
+            max_seqlen_kv=[32, 32, 128, 64],
             cu_seqlens_q=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128]],
             cu_seqlens_kv=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128]],
         ),
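Each of the new per-tensor maxima is consistent with the cumulative-length arrays beneath it: the maximum sequence length is the largest gap between consecutive cu_seqlens entries. A quick hedged check in Python (illustrative, not part of the commit):

    # Fourth test case above: the gaps are 16, 32, 16, 64, so the max is 64.
    cu_seqlens_q = [0, 16, 48, 64, 128]
    max_seqlen_q = max(b - a for a, b in zip(cu_seqlens_q, cu_seqlens_q[1:]))
    assert max_seqlen_q == 64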

diopi_test/python/conformance/customized_test.py (+2, -1)

@@ -477,10 +477,11 @@ def flash_attention_v3(q, k, v, p_dropout, softmax_scale, is_causal):
         return output

     def flash_attention_varlen(
-        q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen, p_dropout, softmax_scale, is_causal
+        q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, p_dropout, softmax_scale, is_causal
     ):
         # Currently, only equality between cu_seqlens_q and cu_seqlens_kv is supported here
         cu_seqlens = cu_seqlens_q
+        max_seqlen = max_seqlen_q
         # In order to compare the accuracy with the baseline value, dropout is not used during testing.
         batch_size = len(cu_seqlens) - 1
         _, head_num, head_dim = q.size()
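Because this reference path only supports cu_seqlens_q == cu_seqlens_kv, the pair of maxima collapses to the single max_seqlen used downstream. For orientation, a per-sequence reference for packed varlen attention typically looks like the sketch below (illustrative names and shapes, not the suite's actual code):

    import torch

    def varlen_attention_reference(q, k, v, cu_seqlens, softmax_scale):
        # q, k, v: [total_tokens, head_num, head_dim], sequences packed back to back
        outs = []
        for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            qi, ki, vi = q[start:end], k[start:end], v[start:end]
            attn = torch.einsum("qhd,khd->hqk", qi, ki) * softmax_scale
            attn = attn.softmax(dim=-1)
            outs.append(torch.einsum("hqk,khd->qhd", attn, vi))
        return torch.cat(outs, dim=0)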

diopi_test/python/conformance/diopi_functions.py (+49, -1)

@@ -5457,7 +5457,7 @@ def flash_attention_v3_backward(q, k, v, out, grad_outputs, p_dropout, softmax_s
     check_returncode(ret)
     return {'q': grad_q, 'k': grad_k, 'v': grad_v}

-def flash_attention_varlen(q, k, v, max_seqlen, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal):
+def flash_attention_varlen(q, k, v, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal):
     call = "diopiFlashAttentionVarLen"
     func = check_function(call)
     q_size = list(q.size().data)

@@ -5500,11 +5500,59 @@ def flash_attention_varlen(q, k, v, max_seqlen, cu_seqlens_q, cu_seqlens_kv, p_d
         v,
         cu_seqlens_q,
         cu_seqlens_kv,
+        max_seqlen_q,
+        max_seqlen_kv,
         p_dropout,
         softmax_scale,
         is_causal,
     )
     check_returncode(ret)
+    GLOBAL_STATE["flash_attention_varlen_attention_mask"] = attention_mask
+    GLOBAL_STATE["flash_attention_varlen_dropout_mask"] = dropout_mask
+    GLOBAL_STATE["flash_attention_varlen_softmax_max"] = softmax_max
+    GLOBAL_STATE["flash_attention_varlen_softmax_sum"] = softmax_sum
+    GLOBAL_STATE["flash_attention_varlen_softmax_out"] = softmax_out
+    return out
+
+def flash_attention_varlen_backward(q, k, v, out, grad_outputs, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, p_dropout, softmax_scale, is_causal):
+    call = "diopiFlashAttentionVarLenBackward"
+    func = check_function(call)
+    assert p_dropout >= 0 and p_dropout <= 1, "The p_dropout value must be in range of [0, 1]"
+    head_dim = q.shape().data[-1]
+    softmax_scale = 1.0 / math.sqrt(head_dim) if not softmax_scale else softmax_scale
+    cu_seqlens_q = Sizes(cu_seqlens_q[1:])
+    cu_seqlens_kv = Sizes(cu_seqlens_kv[1:])
+    grad_q = raw_like(q)
+    grad_k = raw_like(k)
+    grad_v = raw_like(v)
+    attention_mask = GLOBAL_STATE.pop("flash_attention_varlen_attention_mask")
+    dropout_mask = GLOBAL_STATE.pop("flash_attention_varlen_dropout_mask")
+    softmax_max = GLOBAL_STATE.pop("flash_attention_varlen_softmax_max")
+    softmax_sum = GLOBAL_STATE.pop("flash_attention_varlen_softmax_sum")
+    softmax_out = GLOBAL_STATE.pop("flash_attention_varlen_softmax_out")
+    ret = func(
+        q.context(),
+        grad_q,
+        grad_k,
+        grad_v,
+        grad_outputs[0],
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_kv,
+        out,
+        attention_mask,
+        dropout_mask,
+        softmax_max,
+        softmax_sum,
+        softmax_out,
+        max_seqlen_q,
+        max_seqlen_kv,
+        p_dropout,
+        softmax_scale,
+    )
+    check_returncode(ret)
     return out

 def scaled_masked_softmax(input, mask, scale, fixed_triu_mask):
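The forward wrapper now stashes the kernel's intermediates under well-known keys so the new backward wrapper can retrieve them later in the same test case. A minimal sketch of that stash/pop handshake (stub values, not the suite's code):

    GLOBAL_STATE = {}

    def forward_stub(softmax_max):
        # The forward pass saves the intermediate that backward will need.
        GLOBAL_STATE["flash_attention_varlen_softmax_max"] = softmax_max

    def backward_stub():
        # pop() both retrieves the value and clears the slot, so a stale
        # intermediate from one test case cannot leak into the next.
        return GLOBAL_STATE.pop("flash_attention_varlen_softmax_max")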

impl/ascend_npu/diopi_impl/functions_ext/flash_attention_varlen.cpp (+4, -4)

@@ -22,8 +22,8 @@ const int64_t uInt8BitNumber = 8;
 diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHandle_t attentionOut, diopiTensorHandle_t* attentionMask,
                                        diopiTensorHandle_t* dropoutMask, diopiTensorHandle_t* softmaxMax, diopiTensorHandle_t* softmaxSum,
                                        diopiTensorHandle_t* softmaxOut, diopiGeneratorHandle_t gen, diopiConstTensorHandle_t q, diopiConstTensorHandle_t k,
-                                       diopiConstTensorHandle_t v, diopiSize_t cumSeqQ, diopiSize_t cumSeqKV, double pDropout, double softmaxScale,
-                                       bool isCausal) {
+                                       diopiConstTensorHandle_t v, diopiSize_t cumSeqQ, diopiSize_t cumSeqKV, int64_t maxSeqLenQ, int64_t maxSeqLenKV,
+                                       double pDropout, double softmaxScale, bool isCausal) {
     BEGIN_CALL_ACL_OP(q, k, v, cumSeqQ, cumSeqKV, gen, attentionOut);

     DIOPI_CHECK(qAt.dim() == 3, "The shapes of the input query should be 3-dimensional");

@@ -68,7 +68,7 @@ diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopiTensorHand

     at::Tensor attentionMaskAt = at::Tensor();
     if (isCausal) {
-        attentionMaskAt = npu_preparation::apply_tensor_without_format({t, t}, qAt.options().dtype(at::kBool));
+        attentionMaskAt = npu_preparation::apply_tensor_without_format({maxSeqLenQ, maxSeqLenKV}, qAt.options().dtype(at::kBool));
         EXEC_NPU_CMD(aclnnInplaceOne, attentionMaskAt);
         int64_t diagonal = 1;
         EXEC_NPU_CMD(aclnnInplaceTriu, attentionMaskAt, diagonal);

@@ -131,7 +131,7 @@ diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ctx, diopiTe
                                                diopiConstTensorHandle_t v, diopiSize_t cumSeqQ, diopiSize_t cumSeqKV, diopiConstTensorHandle_t attentionOut,
                                                diopiConstTensorHandle_t attentionMask, diopiConstTensorHandle_t dropoutMask,
                                                diopiConstTensorHandle_t softmaxMax, diopiConstTensorHandle_t softmaxSum, diopiConstTensorHandle_t softmaxOut,
-                                               double pDropout, double softmaxScale) {
+                                               int64_t maxSeqLenQ, int64_t maxSeqLenKV, double pDropout, double softmaxScale) {
     BEGIN_CALL_ACL_OP(q, k, v, cumSeqQ, cumSeqKV, attentionOut, softmaxMax, softmaxSum, softmaxOut, gradQ, gradK, gradV, gradOut);

     at::Tensor dropoutMaskAt;
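The key fix here: the causal mask used to be built with a fixed {t, t} shape, and now follows the actual {maxSeqLenQ, maxSeqLenKV} extent of the batch. Assuming aclnnInplaceOne fills the tensor with ones and aclnnInplaceTriu keeps the strict upper triangle (diagonal = 1), the resulting mask is equivalent to this PyTorch one-liner, where True marks positions to mask out (illustrative, not the kernel's code):

    import torch
    mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool).triu(diagonal=1)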

proto/include/diopi/functions_ext.h (+7, -3)

@@ -328,6 +328,8 @@ DIOPI_API diopiError_t diopiFlashAttentionV3Backward(diopiContextHandle_t ctx, d
  * @param[in] v Value tensor. shape = [total_v, head_num, head_dim], where total_v = total number of value tokens in the batch. type = [bfloat16, float16].
  * @param[in] cum_seq_q The cumulative sequence lengths of the sequences in the batch for query. shape = [batch_size+1].
  * @param[in] cum_seq_kv The cumulative sequence lengths of the sequences in the batch for key and value. shape = [batch_size+1].
+ * @param[in] max_seqlen_q Maximum sequence length for query.
+ * @param[in] max_seqlen_kv Maximum sequence length for key and value.
  * @param[in] p_dropout Dropout probability.
  * @param[in] softmax_scale The scaling of qk^T before applying softmax. By default, softmax\_scale=\frac{1}{\sqrt{d_k}}
  * @param[in] is_causal Whether to apply causal attention mask.

@@ -343,7 +345,7 @@ DIOPI_API diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopi
                                                  diopiTensorHandle_t* dropout_mask, diopiTensorHandle_t* softmax_max, diopiTensorHandle_t* softmax_sum,
                                                  diopiTensorHandle_t* softmax_out, diopiGeneratorHandle_t gen, diopiConstTensorHandle_t q,
                                                  diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, diopiSize_t cum_seq_q, diopiSize_t cum_seq_kv,
-                                                 double p_dropout, double softmax_scale, bool is_causal);
+                                                 int64_t max_seqlen_q, int64_t max_seqlen_kv, double p_dropout, double softmax_scale, bool is_causal);

 /**
  * @brief Compute the backward pass for the variable length version of Flash Attention.

@@ -360,6 +362,8 @@ DIOPI_API diopiError_t diopiFlashAttentionVarLen(diopiContextHandle_t ctx, diopi
  * @param[in] softmax_max Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
  * @param[in] softmax_sum Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
  * @param[in] softmax_out Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
+ * @param[in] max_seqlen_q Maximum sequence length for query.
+ * @param[in] max_seqlen_kv Maximum sequence length for key and value.
  * @param[in] p_dropout Dropout probability.
  * @param[in] softmax_scale The scaling of qk^T before applying softmax. By default, softmax\_scale=\frac{1}{\sqrt{d_k}}
  * @param[out] grad_q The gradient of the query tensor. shape = [total_q, head_num, head_dim], where total_q = total number of query tokens in the batch. type =

@@ -374,8 +378,8 @@ DIOPI_API diopiError_t diopiFlashAttentionVarLenBackward(diopiContextHandle_t ct
                                                          diopiConstTensorHandle_t k, diopiConstTensorHandle_t v, diopiSize_t cum_seq_q, diopiSize_t cum_seq_kv,
                                                          diopiConstTensorHandle_t attention_out, diopiConstTensorHandle_t attention_mask,
                                                          diopiConstTensorHandle_t dropout_mask, diopiConstTensorHandle_t softmax_max,
-                                                         diopiConstTensorHandle_t softmax_sum, diopiConstTensorHandle_t softmax_out, double p_dropout,
-                                                         double softmax_scale);
+                                                         diopiConstTensorHandle_t softmax_sum, diopiConstTensorHandle_t softmax_out, int64_t max_seqlen_q,
+                                                         int64_t max_seqlen_kv, double p_dropout, double softmax_scale);

 // This interface is temporarily designed for ascend, please do not use it with other devices.
 /**
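The doc comments give the default scale as softmax_scale = 1/sqrt(d_k), which lines up with the literal values in the test config above, assuming head dims of 64 and 128 for those cases (an inference from the numbers, not stated in the diff):

    import math
    assert abs(1.0 / math.sqrt(64) - 0.125) < 1e-9   # d_k = 64  -> 0.125
    assert abs(1.0 / math.sqrt(128) - 0.0883) < 1e-3 # d_k = 128 -> ~0.0883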
