Skip to content

Fix PyTorch 2.10 ABI compatibility and build logging#2256

Closed
ussoewwin wants to merge 2 commits intoDao-AILab:mainfrom
ussoewwin:fix-pytorch-2.10-abi
Closed

Fix PyTorch 2.10 ABI compatibility and build logging#2256
ussoewwin wants to merge 2 commits intoDao-AILab:mainfrom
ussoewwin:fix-pytorch-2.10-abi

Conversation

@ussoewwin
Copy link
Copy Markdown

FlashAttention PyTorch 2.10+ Compatibility Fixes

This Pull Request addresses build and runtime compatibility issues with PyTorch 2.10 (and upcoming 2.11) on Windows (and potentially Linux).

Problem Description

Building FlashAttention with PyTorch 2.10+ (specifically 2.11.0a0 development builds) typically succeeds, but results in runtime errors or DLL load failures when importing the extension.

Key issues identified:

  1. Header Inclusion: The original csrc/flash_attn/flash_api.cpp deliberately includes <torch/python.h> instead of <torch/extension.h> to reduce compilation time. However, in newer PyTorch versions (2.10+), this configuration seems insufficient for full ABI compatibility for extensions, leading to missing symbols or mismatched definitions at runtime.
  2. Build Environment (Windows): setup.py and Windows build scripts lacked robustness against certain environment variable configurations.

Recommended Changes for PR

1. C++ Extension Header (csrc/flash_attn/flash_api.cpp)

  • Change: Replaced #include <torch/python.h> with #include <torch/extension.h>.
  • Reason: While the original code avoided this header to save compilation time, <torch/extension.h> is the standard and recommended header for C++ extensions to ensure ABI compatibility (_GLIBCXX_USE_CXX11_ABI, etc.). The runtime stability gained for PyTorch 2.10+ outweighs the minor increase in compilation time.

2. Build Script Improvements (setup.py)

  • Change: Added build-time logging for:
    • PyTorch Version
    • CUDA Version
    • _GLIBCXX_USE_CXX11_ABI status
  • Reason: This provides critical context in build logs without affecting the build process itself.

Verification

The fixes were verified on the following environment:

  • OS: Windows (x64)
  • PyTorch: 2.10.0+cu130 (Nightly/Dev 2.11.0a0)
  • CUDA: 12.x
  • FlashAttention: v2.8.3

Verification Script (test_flash_attn.py)

A simple forward pass test was performed to confirm stability:

import torch
import flash_attn

print("Testing FlashAttention forward pass...")
q = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)

try:
    out = flash_attn.flash_attn_func(q, k, v)
    print("Success: Flash Attention forward pass executed.")
    print(f"Output shape: {out.shape}")
    assert not torch.isnan(out).any(), "Output contains NaNs!"
except Exception as e:
    print(f"Failed: {e}")

@janeyx99
Copy link
Copy Markdown
Contributor

janeyx99 commented Feb 12, 2026

@ussoewwin ideally FA2 can also follow the steps of hopper FA3 and be ABI stable with both CPython and libtorch

relevant PRs: #1662 and #1791

@ussoewwin
Copy link
Copy Markdown
Author

ussoewwin commented Feb 13, 2026

@janeyx99 Thanks for your technical advice.

I have amended the code based on the two PRs.

Summary of Changes:

  1. Migrated to TORCH_LIBRARY (Stable ABI):

    • In csrc/flash_attn/flash_api.cpp, I replaced PYBIND11_MODULE with TORCH_LIBRARY and TORCH_LIBRARY_IMPL.
    • This aligns with PyTorch's native operator registration mechanism and ensures better ABI compatibility.
  2. Operator Registration:

    • Defined operators (fwd, varlen_fwd, etc.) using the torch::library schema, matching the style of the referenced PRs.
  3. Python Interface Update:

    • Updated flash_attn/flash_attn_interface.py to invoke kernels via torch.ops.flash_attn_2_cuda and utilize torch.library.custom_op (for PyTorch 2.4+).

This implementation should now fully comply with the Stable ABI requirements.

@ussoewwin ussoewwin closed this Feb 13, 2026
@ussoewwin ussoewwin deleted the fix-pytorch-2.10-abi branch February 13, 2026 17:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants