diff --git a/examples/flash_attention_4_example.py b/examples/flash_attention_4_example.py new file mode 100644 index 000000000000..54c7e6765603 --- /dev/null +++ b/examples/flash_attention_4_example.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +""" +Example: Using Flash Attention 4 (FA4) with HuggingFace Transformers + +This example demonstrates how to use FA4 with transformer models for inference. + +Requirements: + - CUDA-capable GPU with SM 8.0+ (Ampere, Hopper, or Blackwell) + - flash-attn package with CuTe DSL support + - transformers with FA4 integration + +Usage: + python examples/flash_attention_4_example.py +""" + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, is_flash_attn_4_available + + +def check_fa4_availability(): + """Check if FA4 is available on this system.""" + print("Checking FA4 availability...") + + if not torch.cuda.is_available(): + print("ERROR: CUDA not available. FA4 requires CUDA GPU.") + return False + + major, minor = torch.cuda.get_device_capability() + compute_cap = major * 10 + minor + + print(f" GPU: {torch.cuda.get_device_name(0)}") + print(f" Compute Capability: SM {major}.{minor}") + + if compute_cap < 80: + print(f" ERROR: FA4 requires SM 8.0+, found SM {major}.{minor}") + return False + + if compute_cap < 90: + print(f" NOTE: FA4 is optimized for SM 9.0+") + print(f" Your GPU will have reduced optimizations") + + fa4_available = is_flash_attn_4_available() + print(f" FA4 Available: {fa4_available}") + + if not fa4_available: + print("\n FA4 not available. Possible reasons:") + print(" 1. flash-attn package not installed") + print(" 2. flash-attn version doesn't include CuTe DSL") + print(" 3. 
GPU compute capability insufficient") + print("\n Install with: pip install flash-attn --upgrade") + + return fa4_available + + +def example_basic_inference(): + """Example 1: Basic inference with FA4.""" + print("\n" + "=" * 70) + print("Example 1: Basic Inference with FA4") + print("=" * 70) + + model_name = "gpt2" + print(f"\nLoading model: {model_name}") + + # Load model with explicit FA4 attention + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="auto", + attn_implementation="flash_attention_4", # Explicitly request FA4 + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Verify FA4 is being used + print(f"Attention implementation: {model.config._attn_implementation}") + + # Generate text + prompt = "The future of artificial intelligence is" + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + print(f"\nPrompt: {prompt}") + print("Generating...") + + with torch.inference_mode(): + outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nGenerated: {generated_text}") + + +def example_auto_selection(): + """Example 2: Automatic FA4 selection.""" + print("\n" + "=" * 70) + print("Example 2: Automatic FA4 Selection") + print("=" * 70) + + model_name = "gpt2" + print(f"\nLoading model: {model_name}") + + # Load without specifying attention implementation + # Will auto-select best available (FA4 > FA3 > FA2 > SDPA) + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16, device_map="auto" + # attn_implementation not specified - will auto-select + ) + + print(f"Auto-selected attention: {model.config._attn_implementation}") + + tokenizer = AutoTokenizer.from_pretrained(model_name) + + prompt = "Machine learning is revolutionizing" + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + print(f"\nPrompt: {prompt}") + 
print("Generating...") + + with torch.inference_mode(): + outputs = model.generate(**inputs, max_new_tokens=50) + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"\nGenerated: {generated_text}") + + +def example_compare_implementations(): + """Example 3: Compare FA4 vs other implementations.""" + print("\n" + "=" * 70) + print("Example 3: Comparing Attention Implementations") + print("=" * 70) + + model_name = "gpt2" + prompt = "The quick brown fox" + + implementations = ["flash_attention_4", "flash_attention_2", "sdpa", "eager"] + + for impl in implementations: + print(f"\n--- Testing {impl} ---") + + try: + # Load model with specific implementation + model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation=impl + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + # Measure inference time + import time + + torch.cuda.synchronize() + start = time.time() + + with torch.inference_mode(): + outputs = model.generate(**inputs, max_new_tokens=20) + + torch.cuda.synchronize() + elapsed = time.time() - start + + print(f" Implementation: {model.config._attn_implementation}") + print(f" Time: {elapsed:.4f}s") + + except Exception as e: + print(f" ERROR: {e}") + + +def main(): + print("\n" + "=" * 70) + print(" Flash Attention 4 (FA4) Example") + print("=" * 70) + + # Check FA4 availability + if not check_fa4_availability(): + print("\nFA4 not available. 
Examples will use fallback implementations.") + print("Install flash-attn to enable FA4.") + + # Run examples + try: + example_basic_inference() + except Exception as e: + print(f"\nExample 1 failed: {e}") + + try: + example_auto_selection() + except Exception as e: + print(f"\nExample 2 failed: {e}") + + try: + example_compare_implementations() + except Exception as e: + print(f"\nExample 3 failed: {e}") + + print("\n" + "=" * 70) + print(" Examples Complete") + print("=" * 70 + "\n") + + +if __name__ == "__main__": + main() diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 5e31bdae01eb..ee49a24a66d4 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -22,6 +22,7 @@ from .utils import ( is_flash_attn_2_available, is_flash_attn_3_available, + is_flash_attn_4_available, is_flash_attn_greater_or_equal_2_10, is_torch_npu_available, is_torch_xpu_available, @@ -34,7 +35,8 @@ # TODO Deprecate when all models have the attention interface def flash_attn_supports_top_left_mask(): - if is_flash_attn_3_available(): + # FA4 doesn't support top-left mask (similar to FA3) + if is_flash_attn_4_available() or is_flash_attn_3_available(): return False if is_flash_attn_2_available(): return not is_flash_attn_greater_or_equal_2_10() @@ -47,7 +49,8 @@ def flash_attn_supports_top_left_mask(): # TODO Deprecate when all models have the attention interface def is_flash_attn_available(): return ( - is_flash_attn_3_available() + is_flash_attn_4_available() + or is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() or is_torch_xpu_available() @@ -82,10 +85,17 @@ def _lazy_imports(implementation: Optional[str]): """ is_fa2 = is_flash_attn_2_available() is_fa3 = is_flash_attn_3_available() + is_fa4 = is_flash_attn_4_available() pad_input, unpad_input = _pad_input, _unpad_input - if (implementation == 
"flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3): + # Priority order when implementation=None: FA4 > FA3 > FA2 + # FA4 is preferred on supported hardware for optimal performance + if implementation == "flash_attention_4" and is_fa4: + # FA4 uses CuTe DSL implementation from flash_attn.cute submodule + from flash_attn.cute import flash_attn_func, flash_attn_varlen_func + # FA4 doesn't provide pad/unpad functions, use our implementations + elif (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3 and not is_fa4): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import pad_input, unpad_input elif is_torch_npu_available(): @@ -94,8 +104,11 @@ def _lazy_imports(implementation: Optional[str]): from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func else: - if implementation == "flash_attention_3" or (implementation is None and is_fa3): + if implementation == "flash_attention_3" or (implementation is None and is_fa3 and not is_fa4): from flash_attn_interface import flash_attn_func, flash_attn_varlen_func + elif implementation is None and is_fa4: + # Auto-select FA4 when no specific implementation requested and FA4 available + from flash_attn.cute import flash_attn_func, flash_attn_varlen_func # Kernels fallback else: flash_attn_func = getattr(implementation, "flash_attn_func", None) @@ -130,6 +143,24 @@ def _lazy_define_process_function(flash_function): return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping) +def _is_using_fa4(flash_varlen_fn) -> bool: + """ + Detect if we're using FA4 based on function signature. + + FA4 doesn't accept max_seqlen_q/k parameters (calculates internally from cu_seqlens). + This function checks the signature to determine the FA version. 
+ + Args: + flash_varlen_fn: The flash attention varlen function to inspect. + + Returns: + bool: True if using FA4, False otherwise. + """ + flash_parameters = inspect.signature(flash_varlen_fn).parameters + # FA4 doesn't have max_seqlen_q parameter + return "max_seqlen_q" not in flash_parameters + + def lazy_import_flash_attention(implementation: Optional[str], force_import: Optional[bool] = False): """ Lazily import flash attention and return the respective functions + flags. @@ -467,6 +498,9 @@ def _process_flash_attention_kwargs( softcap: Optional[float] = None, deterministic: Optional[bool] = None, s_aux: Optional[torch.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, supports_mapping: Optional[dict[str, bool]] = None, **kwargs, ): @@ -497,6 +531,12 @@ def _process_flash_attention_kwargs( Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. s_aux (`torch.Tensor`, *optional*): Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head. + learnable_sink (`torch.Tensor`, *optional*): + FA4-specific: Learnable sink tokens for attention. Shape: (batch, num_sink_tokens, head_dim). + num_splits (`int`, *optional*): + FA4-specific: Number of splits for split-KV mode. Default: 1 (no splitting). + pack_gqa (`bool`, *optional*): + FA4-specific: Whether to use packed GQA optimization. If None, auto-detected. Return: flash_kwargs (`dict`): A dict of kwargs that are requested and supported. 
@@ -529,6 +569,16 @@ def _process_flash_attention_kwargs( if supports_mapping["s_aux"] and s_aux is not None: flash_kwargs["s_aux"] = s_aux + # FA4-specific parameters + if supports_mapping.get("learnable_sink", False) and learnable_sink is not None: + flash_kwargs["learnable_sink"] = learnable_sink + + if supports_mapping.get("num_splits", False) and num_splits > 1: + flash_kwargs["num_splits"] = num_splits + + if supports_mapping.get("pack_gqa", False) and pack_gqa is not None: + flash_kwargs["pack_gqa"] = pack_gqa + return flash_kwargs @@ -608,6 +658,9 @@ def _flash_attention_forward( kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k) ) + # Detect if we're using FA4 (which doesn't accept max_seqlen parameters) + is_fa4 = _is_using_fa4(flash_varlen_fn) + # Contains at least one padding token in the sequence if attention_mask is not None: q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input( @@ -619,16 +672,28 @@ def _flash_attention_forward( if "mps" in str(q.device): cu_seq_lens_k = cu_seq_lens_k.clone() - out_unpad = flash_varlen_fn( - q, - k, - v, - cu_seqlens_q=cu_seq_lens_q, - cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, - **flash_kwargs, - ) + # FA4 calculates max_seqlen internally from cu_seqlens, so we don't pass it + if is_fa4: + out_unpad = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + **flash_kwargs, + ) + else: + # FA2/FA3 require max_seqlen parameters + out_unpad = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, + ) if isinstance(out_unpad, tuple): out_unpad = out_unpad[0] @@ -650,16 +715,28 @@ def _flash_attention_forward( if "mps" in str(q.device): cu_seq_lens_k = cu_seq_lens_k.clone() - out = flash_varlen_fn( - q, - k, - v, - cu_seqlens_q=cu_seq_lens_q, - 
cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, - **flash_kwargs, - ) + # FA4 calculates max_seqlen internally from cu_seqlens, so we don't pass it + if is_fa4: + out = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + **flash_kwargs, + ) + else: + # FA2/FA3 require max_seqlen parameters + out = flash_varlen_fn( + q, + k, + v, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + **flash_kwargs, + ) if isinstance(out, tuple): out = out[0] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 748d7af639af..4210f1f74761 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4650,6 +4650,7 @@ class AttentionInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if # a new instance is created (in order to locally override a given function) _global_mapping = { + "flash_attention_4": flash_attention_forward, "flash_attention_3": flash_attention_forward, "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index bdaf22e81478..9d5456775231 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -89,6 +89,7 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flash_attn_3_available, + is_flash_attn_4_available, is_flute_available, is_fp_quant_available, is_fsdp_available, @@ -619,6 +620,16 @@ def require_flash_attn_3(test_case): return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case) +def require_flash_attn_4(test_case): + """ + Decorator marking a test that requires Flash Attention 4 (CuTe DSL). 
+ + These tests are skipped when Flash Attention 4 isn't installed or when + CUDA compute capability is insufficient (requires SM 8.0+, optimized for SM 9.0+). + """ + return unittest.skipUnless(is_flash_attn_4_available(), "test requires Flash Attention 4")(test_case) + + def require_read_token(test_case): """ A decorator that loads the HF token for tests that require to load gated models. diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 38b5db8f4893..19e14933c4b3 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -143,6 +143,7 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flash_attn_3_available, + is_flash_attn_4_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_flute_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index bf2fba35fd0e..34cc5267c869 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -857,6 +857,51 @@ def is_flash_attn_3_available() -> bool: return is_torch_cuda_available() and _is_package_available("flash_attn_3") + + +@lru_cache +def is_flash_attn_4_available() -> bool: + """ + Check if Flash Attention 4 (CuTe DSL implementation) is available. + + FA4 is distributed within the flash_attn package under the 'cute' submodule + and requires CUDA compute capability >= 8.0 (optimized for SM 9.0+, Hopper or Blackwell GPUs). + + Returns: + bool: True if FA4 is available and hardware requirements are met. 
+ """ + # FA4 requires CUDA and flash_attn package + if not is_torch_cuda_available(): + return False + + is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True) + if not is_available: + return False + + try: + # FA4 is available via flash_attn.cute submodule + # We verify by attempting to import the cute interface + from flash_attn.cute import flash_attn_func + + # FA4 is optimized for SM 9.0+ (Hopper/Blackwell) + # It can run on SM 8.0 (Ampere) but with limited optimizations + import torch + + if torch.cuda.is_available(): + # Get compute capability of the default device + major, minor = torch.cuda.get_device_capability() + compute_cap = major * 10 + minor + + # FA4 requires SM 8.0+ minimum, optimized for SM 9.0+ + # We allow SM 8.0+ but document that SM 9.0+ is recommended + if compute_cap >= 80: + return True + + return False + + except (ImportError, AttributeError, RuntimeError): + # Import failed - FA4 not available + return False + + @lru_cache def is_flash_attn_greater_or_equal_2_10() -> bool: _, flash_attn_version = _is_package_available("flash_attn", return_version=True) diff --git a/tests/test_flash_attention_4.py b/tests/test_flash_attention_4.py new file mode 100644 index 000000000000..4a9ef5546fed --- /dev/null +++ b/tests/test_flash_attention_4.py @@ -0,0 +1,232 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Tests for Flash Attention 4 (CuTe DSL) integration. + +Usage: + pytest tests/test_flash_attention_4.py -v + + # Run with specific GPU + CUDA_VISIBLE_DEVICES=0 pytest tests/test_flash_attention_4.py -v +""" + +import unittest + +import torch + +from transformers import is_flash_attn_4_available +from transformers.testing_utils import require_flash_attn_4, require_torch_gpu + + +class FlashAttention4DetectionTest(unittest.TestCase): + """Test FA4 detection without requiring GPU.""" + + def test_detection_function_exists(self): + """Verify is_flash_attn_4_available is callable.""" + self.assertTrue(callable(is_flash_attn_4_available)) + + def test_detection_returns_bool(self): + """Verify detection returns boolean.""" + result = is_flash_attn_4_available() + self.assertIsInstance(result, bool) + + +@require_torch_gpu +@require_flash_attn_4 +class FlashAttention4IntegrationTest(unittest.TestCase): + """Integration tests requiring GPU and FA4.""" + + def setUp(self): + """Set up test fixtures.""" + self.device = torch.device("cuda:0") + self.dtype = torch.bfloat16 + self.batch_size = 2 + self.seq_len = 128 + self.num_heads = 8 + self.head_dim = 64 + + def test_fa4_import(self): + """Test FA4 can be imported.""" + try: + from flash_attn.cute import flash_attn_func, flash_attn_varlen_func + + self.assertIsNotNone(flash_attn_func) + self.assertIsNotNone(flash_attn_varlen_func) + except ImportError as e: + self.fail(f"Failed to import FA4: {e}") + + def test_fa4_basic_forward(self): + """Test basic FA4 forward pass.""" + from flash_attn.cute import flash_attn_func + + q = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + k = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + v = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + + try: + out = flash_attn_func(q, 
k, v, causal=False) + self.assertEqual(out.shape, (self.batch_size, self.seq_len, self.num_heads, self.head_dim)) + self.assertEqual(out.dtype, self.dtype) + except Exception as e: + self.fail(f"FA4 forward pass failed: {e}") + + def test_fa4_causal_attention(self): + """Test FA4 with causal masking.""" + from flash_attn.cute import flash_attn_func + + q = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + k = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + v = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + + try: + out = flash_attn_func(q, k, v, causal=True) + self.assertEqual(out.shape, (self.batch_size, self.seq_len, self.num_heads, self.head_dim)) + except Exception as e: + self.fail(f"FA4 causal attention failed: {e}") + + def test_fa4_varlen_no_max_seqlen(self): + """Test FA4 varlen function does not accept max_seqlen parameters.""" + from flash_attn.cute import flash_attn_varlen_func + import inspect + + sig = inspect.signature(flash_attn_varlen_func) + params = list(sig.parameters.keys()) + + # Verify FA4 API: no max_seqlen_q/k parameters + self.assertNotIn("max_seqlen_q", params, "FA4 should not have max_seqlen_q parameter") + self.assertNotIn("max_seqlen_k", params, "FA4 should not have max_seqlen_k parameter") + + # Verify FA4 has cu_seqlens parameters + self.assertIn("cu_seqlens_q", params, "FA4 should have cu_seqlens_q parameter") + self.assertIn("cu_seqlens_k", params, "FA4 should have cu_seqlens_k parameter") + + def test_fa4_varlen_forward(self): + """Test FA4 varlen forward pass.""" + from flash_attn.cute import flash_attn_varlen_func + + # Create packed sequences: [seq1_len=50, seq2_len=78] + total_tokens = 128 + cu_seqlens = torch.tensor([0, 50, 128], dtype=torch.int32, device=self.device) + + q = torch.randn(total_tokens, 
self.num_heads, self.head_dim, device=self.device, dtype=self.dtype) + k = torch.randn(total_tokens, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype) + v = torch.randn(total_tokens, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype) + + try: + # FA4 calculates max_seqlen internally, no need to pass it + out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, causal=False) + self.assertEqual(out.shape, (total_tokens, self.num_heads, self.head_dim)) + except Exception as e: + self.fail(f"FA4 varlen forward failed: {e}") + + def test_hf_fa4_integration(self): + """Test HF's FA4 integration via lazy_import_flash_attention.""" + from transformers.modeling_flash_attention_utils import lazy_import_flash_attention, _is_using_fa4 + + # Test explicit FA4 selection + (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_kwargs_fn = lazy_import_flash_attention( + "flash_attention_4" + ) + + self.assertIsNotNone(flash_fn) + self.assertIsNotNone(flash_varlen_fn) + + # Verify we're using FA4 (no max_seqlen_q parameter) + is_fa4 = _is_using_fa4(flash_varlen_fn) + self.assertTrue(is_fa4, "Should detect FA4 via introspection") + + def test_hf_fa4_auto_selection(self): + """Test HF auto-selects FA4 when available.""" + from transformers.modeling_flash_attention_utils import lazy_import_flash_attention, _is_using_fa4 + + # Test auto-selection (implementation=None) + (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_kwargs_fn = lazy_import_flash_attention(None) + + # Should select FA4 if available + is_fa4 = _is_using_fa4(flash_varlen_fn) + # This will be True if FA4 is the highest priority available + self.assertIsInstance(is_fa4, bool) + + +@require_torch_gpu +@require_flash_attn_4 +class FlashAttention4ParameterTest(unittest.TestCase): + """Test FA4-specific parameters.""" + + def setUp(self): + """Set up test fixtures.""" + self.device = torch.device("cuda:0") + self.dtype = torch.bfloat16 + self.batch_size 
= 2 + self.seq_len = 128 + self.num_heads = 8 + self.head_dim = 64 + + def test_softcap_parameter(self): + """Test FA4 softcap parameter.""" + from flash_attn.cute import flash_attn_func + + q = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + k = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + v = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + + try: + out = flash_attn_func(q, k, v, causal=False, softcap=30.0) + self.assertEqual(out.shape, (self.batch_size, self.seq_len, self.num_heads, self.head_dim)) + except Exception as e: + self.fail(f"FA4 softcap parameter failed: {e}") + + def test_window_size_parameter(self): + """Test FA4 sliding window attention.""" + from flash_attn.cute import flash_attn_func + + q = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + k = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + v = torch.randn( + self.batch_size, self.seq_len, self.num_heads, self.head_dim, device=self.device, dtype=self.dtype + ) + + try: + # FA4 uses (left, right) tuple for window_size + out = flash_attn_func(q, k, v, causal=True, window_size=(32, 32)) + self.assertEqual(out.shape, (self.batch_size, self.seq_len, self.num_heads, self.head_dim)) + except Exception as e: + self.fail(f"FA4 window_size parameter failed: {e}") + + +if __name__ == "__main__": + unittest.main() diff --git a/validate_fa4.py b/validate_fa4.py new file mode 100644 index 000000000000..b0343d004619 --- /dev/null +++ b/validate_fa4.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +""" +Quick FA4 validation script for SSH GPU access. 
+ +Usage: + python validate_fa4.py +""" + +import sys + +import torch + + +def print_section(title): + print(f"\n{'=' * 70}") + print(f" {title}") + print('=' * 70) + + +def check_cuda(): + print_section("CUDA Environment") + if not torch.cuda.is_available(): + print("ERROR: CUDA not available") + return False + + print(f"CUDA available: {torch.cuda.is_available()}") + print(f"CUDA version: {torch.version.cuda}") + print(f"GPU count: {torch.cuda.device_count()}") + print(f"Current device: {torch.cuda.current_device()}") + print(f"Device name: {torch.cuda.get_device_name(0)}") + + major, minor = torch.cuda.get_device_capability() + compute_cap = major * 10 + minor + print(f"Compute capability: SM {major}.{minor} ({compute_cap})") + + if compute_cap < 80: + print(f"WARNING: FA4 requires SM 8.0+, you have SM {major}.{minor}") + return False + + if compute_cap < 90: + print(f"NOTE: FA4 is optimized for SM 9.0+ (Hopper/Blackwell)") + print(f" Your GPU (SM {major}.{minor}) will have limited optimizations") + + return True + + +def check_transformers(): + print_section("Transformers Installation") + try: + import transformers + + print(f"Transformers version: {transformers.__version__}") + print(f"Transformers path: {transformers.__file__}") + return True + except ImportError as e: + print(f"ERROR: Failed to import transformers: {e}") + return False + + +def check_flash_attn_package(): + print_section("Flash Attention Package") + try: + import flash_attn + + print(f"flash-attn installed: Yes") + if hasattr(flash_attn, "__version__"): + print(f"flash-attn version: {flash_attn.__version__}") + return True + except ImportError: + print("ERROR: flash-attn package not installed") + print("Install with: pip install flash-attn") + return False + + +def check_fa4_availability(): + print_section("FA4 Detection") + try: + from transformers import is_flash_attn_4_available + + fa4_available = is_flash_attn_4_available() + print(f"is_flash_attn_4_available(): {fa4_available}") + + if 
not fa4_available: + print("\nFA4 not available. Checking why...") + + # Try manual import + try: + from flash_attn.cute import flash_attn_func + + print(" - flash_attn.cute can be imported manually") + print(" - Detection function may have incorrect logic") + except ImportError as e: + print(f" - flash_attn.cute import failed: {e}") + print(" - FA4 (CuTe DSL) not available in installed flash-attn") + + return fa4_available + except Exception as e: + print(f"ERROR: Detection check failed: {e}") + return False + + +def test_fa4_import(): + print_section("FA4 Import Test") + try: + from flash_attn.cute import flash_attn_func, flash_attn_varlen_func + + print("Successfully imported:") + print(f" - flash_attn_func: {flash_attn_func}") + print(f" - flash_attn_varlen_func: {flash_attn_varlen_func}") + return True + except ImportError as e: + print(f"ERROR: Failed to import FA4: {e}") + return False + + +def test_fa4_signature(): + print_section("FA4 API Signature Check") + try: + import inspect + + from flash_attn.cute import flash_attn_varlen_func + + sig = inspect.signature(flash_attn_varlen_func) + params = list(sig.parameters.keys()) + + print(f"flash_attn_varlen_func parameters: {params}") + + has_max_seqlen = "max_seqlen_q" in params + has_cu_seqlens = "cu_seqlens_q" in params + + print(f"\nAPI check:") + print(f" - Has cu_seqlens_q: {has_cu_seqlens} (expected: True)") + print(f" - Has max_seqlen_q: {has_max_seqlen} (expected: False)") + + if has_cu_seqlens and not has_max_seqlen: + print("\nSUCCESS: FA4 API signature is correct") + return True + else: + print("\nWARNING: Unexpected API signature") + return False + + except Exception as e: + print(f"ERROR: Signature check failed: {e}") + return False + + +def test_fa4_basic_forward(): + print_section("FA4 Basic Forward Test") + try: + from flash_attn.cute import flash_attn_func + + batch_size = 2 + seq_len = 128 + num_heads = 8 + head_dim = 64 + dtype = torch.bfloat16 + device = torch.device("cuda:0") + + q = 
torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + print(f"Input shapes: q={q.shape}, k={k.shape}, v={v.shape}") + + out = flash_attn_func(q, k, v, causal=False) + + print(f"Output shape: {out.shape}") + print(f"Output dtype: {out.dtype}") + + if out.shape == (batch_size, seq_len, num_heads, head_dim): + print("\nSUCCESS: FA4 forward pass works correctly") + return True + else: + print(f"\nERROR: Unexpected output shape") + return False + + except Exception as e: + print(f"ERROR: Forward pass failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def test_hf_integration(): + print_section("HF Integration Test") + try: + from transformers.modeling_flash_attention_utils import _is_using_fa4, lazy_import_flash_attention + + print("Testing explicit FA4 selection...") + (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_kwargs_fn = lazy_import_flash_attention( + "flash_attention_4" + ) + + is_fa4 = _is_using_fa4(flash_varlen_fn) + print(f" - _is_using_fa4(): {is_fa4} (expected: True)") + + print("\nTesting auto-selection...") + (flash_fn_auto, flash_varlen_fn_auto, _, _), _ = lazy_import_flash_attention(None) + + is_fa4_auto = _is_using_fa4(flash_varlen_fn_auto) + print(f" - Auto-selected FA4: {is_fa4_auto}") + + if is_fa4: + print("\nSUCCESS: HF integration works correctly") + return True + else: + print("\nWARNING: HF integration issue detected") + return False + + except Exception as e: + print(f"ERROR: HF integration test failed: {e}") + import traceback + + traceback.print_exc() + return False + + +def main(): + print("\n" + "=" * 70) + print(" Flash Attention 4 (FA4) Validation") + print("=" * 70) + + results = { + "CUDA": check_cuda(), + "Transformers": check_transformers(), + "flash-attn package": check_flash_attn_package(), + 
"FA4 detection": check_fa4_availability(), + "FA4 import": test_fa4_import(), + "FA4 API signature": test_fa4_signature(), + "FA4 forward pass": test_fa4_basic_forward(), + "HF integration": test_hf_integration(), + } + + print_section("Summary") + all_passed = True + for test_name, passed in results.items(): + status = "PASS" if passed else "FAIL" + symbol = "✓" if passed else "✗" + print(f"{symbol} {test_name}: {status}") + if not passed: + all_passed = False + + print("\n" + "=" * 70) + if all_passed: + print(" ALL CHECKS PASSED - FA4 IS READY TO USE") + else: + print(" SOME CHECKS FAILED - SEE ABOVE FOR DETAILS") + print("=" * 70 + "\n") + + sys.exit(0 if all_passed else 1) + + +if __name__ == "__main__": + main()