From ef672d2389f62ac639daf4230862e60e19302ea6 Mon Sep 17 00:00:00 2001
From: Younes Belkada
Date: Wed, 13 Mar 2024 07:52:09 +0000
Subject: [PATCH] fix fix copies

---
 src/transformers/models/gptj/modeling_gptj.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index c495d281db5d..144dbba05527 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -454,7 +454,7 @@ def _flash_attention_forward(
             attention_mask (`torch.Tensor`):
                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                 position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
+            dropout (`float`):
                 Attention dropout
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
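
For context (not part of the patch itself): the corrected annotation reflects that the dropout value is a probability, which the wrapper forwards to the flash-attn kernels as a float. Below is a minimal sketch of such a forward helper, assuming the `flash_attn` package is installed; the helper name and tensor shapes are illustrative, not the exact transformers implementation.

```python
import math
from typing import Optional

import torch
from flash_attn import flash_attn_func  # public flash-attn API


def flash_attention_forward_sketch(
    query: torch.Tensor,   # (batch_size, seq_len, num_heads, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    dropout: float = 0.0,  # a probability, hence `float`, not `int`
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    # Default scaling of QK^T before softmax, matching the docstring:
    # 1 / sqrt(head_dim).
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(query.shape[-1])
    # flash_attn_func expects CUDA tensors in fp16/bf16 and takes the
    # dropout probability straight through as dropout_p.
    return flash_attn_func(
        query,
        key,
        value,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=True,  # GPT-J uses causal (autoregressive) attention
    )
```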