Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions ring_flash_attn/adapters/hf_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@
import torch.distributed as dist
import transformers
import transformers.modeling_flash_attention_utils
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size,
is_flash_attn_greater_or_equal,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window,
is_flash_attn_greater_or_equal,
)
except ImportError:
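# Fallback for transformers <= 4.53.x, which exposes this flag under the
# older name `_flash_supports_window_size`; alias it to the new name.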
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
is_flash_attn_greater_or_equal,
)

from ..llama3_flash_attn_varlen import (
llama3_flash_attn_varlen_func,
llama3_flash_attn_prepare_cu_seqlens,
Expand Down Expand Up @@ -111,7 +119,7 @@ def _flash_attention_forward(

# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size
_flash_supports_window
and sliding_window is not None
and key_states.shape[1] > sliding_window
)
Expand Down