proto/include/diopi/functions_ext.h (+7 -3)
@@ -328,6 +328,8 @@ DIOPI_API diopiError_t diopiFlashAttentionV3Backward(diopiContextHandle_t ctx, d
* @param[in] v Value tensor. shape = [total_v, head_num, head_dim], where total_v = total number of value tokens in the batch. type = [bfloat16, float16].
* @param[in] cum_seq_q The cumulative sequence lengths of the sequences in the batch for query. shape = [batch_size+1].
* @param[in] cum_seq_kv The cumulative sequence lengths of the sequences in the batch for key and value. shape = [batch_size+1].
+ * @param[in] max_seqlen_q Maximum sequence length for query.
+ * @param[in] max_seqlen_kv Maximum sequence length for key and value.
* @param[in] p_dropout Dropout probability.
* @param[in] softmax_scale The scaling of qk^T before applying softmax. By default, softmax\_scale=\frac{1}{\sqrt{d_k}}
* @param[in] is_causal Whether to apply causal attention mask.
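The relationship between the per-sequence lengths, `cum_seq_q` / `cum_seq_kv`, and the new `max_seqlen_q` / `max_seqlen_kv` arguments can be sketched as follows. This is only an illustration under stated assumptions: `build_cum_seq` is a hypothetical helper and not part of DIOPI, the cumulative array is assumed to start at 0 (which is what the `[batch_size+1]` shape suggests), and `int64_t` is an assumed element type since the header does not state one here.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper (not a DIOPI function).
 * Fills cum_seq, which must hold batch_size + 1 entries, with the prefix
 * sums of seq_lens, starting at 0, and returns the longest sequence length.
 * Assumption: cum_seq[i] is the offset of sequence i in the packed tensor,
 * so cum_seq[batch_size] equals the total token count (total_q or total_v). */
static int64_t build_cum_seq(const int64_t *seq_lens, size_t batch_size,
                             int64_t *cum_seq) {
    assert(seq_lens != NULL && cum_seq != NULL);
    int64_t max_len = 0;
    cum_seq[0] = 0;
    for (size_t i = 0; i < batch_size; ++i) {
        assert(seq_lens[i] >= 0);
        cum_seq[i + 1] = cum_seq[i] + seq_lens[i];
        if (seq_lens[i] > max_len) {
            max_len = seq_lens[i];
        }
    }
    return max_len;
}
```

Under these assumptions, a caller would run the helper once on the query lengths and once on the key/value lengths, and pass each returned maximum as `max_seqlen_q` or `max_seqlen_kv` respectively.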
* @param[in] softmax_max Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
* @param[in] softmax_sum Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
* @param[in] softmax_out Tensor representing the intermediate calculation result of softmax op from the forward pass. type = [float32].
+ * @param[in] max_seqlen_q Maximum sequence length for query.
+ * @param[in] max_seqlen_kv Maximum sequence length for key and value.
* @param[in] p_dropout Dropout probability.
* @param[in] softmax_scale The scaling of qk^T before applying softmax. By default, softmax\_scale=\frac{1}{\sqrt{d_k}}
* @param[out] grad_q The gradient of the query tensor. shape = [total_q, head_num, head_dim], where total_q = total number of query tokens in the batch. type =