From a4aa2360218db1ec20fc5a31f9fe3e03a4121dcb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 23 Jul 2025 07:14:50 -0400 Subject: [PATCH] use updated var from hf refactor --- ring_flash_attn/adapters/hf_adapter.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ring_flash_attn/adapters/hf_adapter.py b/ring_flash_attn/adapters/hf_adapter.py index 8ad3f03..b2c95c6 100644 --- a/ring_flash_attn/adapters/hf_adapter.py +++ b/ring_flash_attn/adapters/hf_adapter.py @@ -6,10 +6,18 @@ import torch.distributed as dist import transformers import transformers.modeling_flash_attention_utils -from transformers.modeling_flash_attention_utils import ( - _flash_supports_window_size, - is_flash_attn_greater_or_equal, -) +try: + from transformers.modeling_flash_attention_utils import ( + _flash_supports_window, + is_flash_attn_greater_or_equal, + ) +except ImportError: + # transformers <= 4.53.x + from transformers.modeling_flash_attention_utils import ( + _flash_supports_window_size as _flash_supports_window, + is_flash_attn_greater_or_equal, + ) + from ..llama3_flash_attn_varlen import ( llama3_flash_attn_varlen_func, llama3_flash_attn_prepare_cu_seqlens, @@ -111,7 +119,7 @@ def _flash_attention_forward( # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). use_sliding_windows = ( - _flash_supports_window_size + _flash_supports_window and sliding_window is not None and key_states.shape[1] > sliding_window )