@@ -434,15 +434,15 @@ def flash_attention(
434434 query(Tensor): The query tensor in the Attention module.
435435 4-D tensor with shape:
436436 [batch_size, seq_len, num_heads, head_dim].
437- The dtype can be float61 or bfloat16.
437+ The dtype can be float16 or bfloat16.
438438 key(Tensor): The key tensor in the Attention module.
439439 4-D tensor with shape:
440440 [batch_size, seq_len, num_heads, head_dim].
441- The dtype can be float61 or bfloat16.
441+ The dtype can be float16 or bfloat16.
442442 value(Tensor): The value tensor in the Attention module.
443443 4-D tensor with shape:
444444 [batch_size, seq_len, num_heads, head_dim].
445- The dtype can be float61 or bfloat16.
445+ The dtype can be float16 or bfloat16.
446446 dropout(float): The dropout ratio.
447447 causal(bool): Whether to enable causal mode.
448448 return_softmax(bool): Whether to return softmax.
@@ -623,6 +623,157 @@ def flash_attention(
623623 )
624624
625625
626+ @overload
627+ def flash_attention_v3_varlen(
628+ query: Tensor,
629+ key: Tensor,
630+ value: Tensor,
631+ cu_seqlens_q: Tensor,
632+ cu_seqlens_k: Tensor,
633+ dropout: float = ...,
634+ causal: bool = ...,
635+ return_softmax: Literal[False] = ...,
636+ *,
637+ fixed_seed_offset: Tensor | None = ...,
638+ rng_name: str = ...,
639+ training: bool = ...,
640+ softmax_scale: float | None = ...,
641+ max_seqlen_q: int = ...,
642+ max_seqlen_k: int = ...,
643+ name: str | None = ...,
644+ ) -> tuple[Tensor, None]: ...
645+
646+
647+ @overload
648+ def flash_attention_v3_varlen(
649+ query: Tensor,
650+ key: Tensor,
651+ value: Tensor,
652+ cu_seqlens_q: Tensor,
653+ cu_seqlens_k: Tensor,
654+ dropout: float = ...,
655+ causal: bool = ...,
656+ return_softmax: Literal[True] = ...,
657+ *,
658+ fixed_seed_offset: Tensor | None = ...,
659+ rng_name: str = ...,
660+ training: bool = ...,
661+ softmax_scale: float | None = ...,
662+ max_seqlen_q: int = ...,
663+ max_seqlen_k: int = ...,
664+ name: str | None = ...,
665+ ) -> tuple[Tensor, Tensor]: ...
666+
667+
668+ def flash_attention_v3_varlen(
669+ query,
670+ key,
671+ value,
672+ cu_seqlens_q,
673+ cu_seqlens_k,
674+ dropout=0.0,
675+ causal=False,
676+ return_softmax=False,
677+ *,
678+ fixed_seed_offset=None,
679+ rng_name="",
680+ training=True,
681+ softmax_scale=None,
682+ max_seqlen_q=0,
683+ max_seqlen_k=0,
684+ name=None,
685+ ):
686+ r"""
687+ The equation is:
688+
689+ .. math::
690+
691+ result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
692+
693+ where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
694+ The dimensions of the three parameters are the same.
695+ ``d`` represents the size of the last dimension of the three parameters.
696+ This is the varlen version of flash attention.
697+
698+ Warning:
699+ This API only supports inputs with dtype float16 and bfloat16.
700+
701+ Args:
702+ query(Tensor): The query tensor in the Attention module.
703+ 3-D tensor with shape:
704+ [token_num, num_heads, head_dim].
705+ The dtype can be float16 or bfloat16.
706+ key(Tensor): The key tensor in the Attention module.
707+ 3-D tensor with shape:
708+ [token_num, num_heads, head_dim].
709+ The dtype can be float16 or bfloat16.
710+ value(Tensor): The value tensor in the Attention module.
711+ 3-D tensor with shape:
712+ [token_num, num_heads, head_dim].
713+ The dtype can be float16 or bfloat16.
714+ cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
715+ used to index query.
716+ cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
717+ used to index key and value.
718+ dropout(float): The dropout ratio.
719+ causal(bool): Whether to enable causal mode.
720+ return_softmax(bool): Whether to return softmax.
721+ fixed_seed_offset(Tensor|None, optional): With fixed seed, offset for dropout mask.
722+ rng_name(str): The name to select Generator.
723+ training(bool): Whether it is in the training phase.
724+ softmax_scale(float|None, optional): The softmax scale of the attention. If None, it defaults to ``head_dim ** -0.5``.
725+ max_seqlen_q(int): Maximum sequence length of query in the batch. Note that this is the padding length, not the maximum actual sequence length.
726+ max_seqlen_k(int): Maximum sequence length of key/value in the batch.
727+ name(str|None, optional): The default value is None. Normally there is no need for user
728+ to set this property. For more information, please refer to
729+ :ref:`api_guide_Name`.
730+
731+ Returns:
732+ out(Tensor): The attention tensor. 3-D tensor with shape: [token_num, num_heads, head_dim]. The dtype can be float16 or bfloat16.
733+ softmax(Tensor): The softmax tensor. None if return_softmax is False.
734+
735+ Examples:
736+ .. code-block:: python
737+
738+ >>> # doctest: +SKIP('flash_attn_v3 needs H100 compile')
739+ >>> import paddle
740+
741+ >>> paddle.seed(2023)
742+ >>> q = paddle.rand((10, 2, 128), dtype="bfloat16")
743+ >>> cu_seqlens_q = paddle.to_tensor([0, 10], dtype="int32")
744+ >>> max_seq_len_q = 10
745+
746+ >>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
747+ >>> # doctest: -SKIP
748+
749+ """
750+ if softmax_scale is None:
751+ softmax_scale = query.shape[-1] ** (-0.5)
752+ out, softmax_lse = _C_ops.flash_attn_v3_varlen(
753+ query,
754+ key,
755+ value,
756+ cu_seqlens_q,
757+ cu_seqlens_k,
758+ None,  # q_v_
759+ None,  # q_descale_
760+ None,  # k_descale_
761+ None,  # v_descale_
762+ softmax_scale,
763+ causal,
764+ -1,  # window_size_left
765+ -1,  # window_size_right
766+ 0.0,  # softcap
767+ 1,  # num_splits
768+ False,  # manual_set_pack_gqa
769+ False,  # pack_gqa_
770+ 0,  # sm_margin
771+ max_seqlen_q,
772+ max_seqlen_k,
773+ )
774+ return out, softmax_lse
775+
776+
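The packed varlen layout above is easiest to see with a concrete call. The following sketch is editorial, not part of the patch: it assumes a build with flash_attn_v3 support (the H100 compile the doctest skip mentions), and the names `seq_lens`, `cu_seqlens`, and `q` are illustrative. It packs two sequences of lengths 3 and 7 into one token-major tensor and marks the segment boundaries with a cumulative-length tensor, mirroring the single-sequence example in the docstring.

    import paddle
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen

    seq_lens = [3, 7]               # actual per-sequence lengths
    token_num = sum(seq_lens)       # 10 tokens after packing
    # Packed query/key/value: [token_num, num_heads, head_dim], bfloat16.
    q = paddle.rand((token_num, 2, 128), dtype="bfloat16")
    # Running sum of seq_lens prefixed with 0; sequence i spans
    # tokens [cu_seqlens[i], cu_seqlens[i+1]) of the packed tensor.
    cu_seqlens = paddle.to_tensor([0, 3, 10], dtype="int32")

    out, _ = flash_attention_v3_varlen(
        q, q, q,                    # self-attention: share one packed tensor
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lens), # upper bound per sequence (see max_seqlen_q note)
        max_seqlen_k=max(seq_lens),
        causal=True,
    )
    print(out.shape)                # [10, 2, 128]

Because attention is computed per segment, tokens of the length-3 sequence never attend to tokens of the length-7 sequence, and the causal mask is applied within each sequence rather than across the packed batch.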
626777@overload
627778def flash_attn_qkvpacked (
628779 qkv : Tensor ,
@@ -912,15 +1063,15 @@ def flash_attn_unpadded(
9121063 query(Tensor): The query tensor in the Attention module.
9131064 3-D tensor with shape:
9141065 [total_seq_len, num_heads, head_dim].
915- The dtype can be float61 or bfloat16.
1066+ The dtype can be float16 or bfloat16.
9161067 key(Tensor): The key tensor in the Attention module.
9171068 3-D tensor with shape:
9181069 [total_seq_len, num_heads, head_dim].
919- The dtype can be float61 or bfloat16.
1070+ The dtype can be float16 or bfloat16.
9201071 value(Tensor): The value tensor in the Attention module.
9211072 3-D tensor with shape:
9221073 [total_seq_len, num_heads, head_dim].
923- The dtype can be float61 or bfloat16.
1074+ The dtype can be float16 or bfloat16.
9241075 cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
9251076 used to index query.
9261077 cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,