
torch_xla.experimental.custom_kernel.flash_attention output does not match F.scaled_dot_product_attention on TPU #8869

Open
NickLucche opened this issue Mar 21, 2025 · 2 comments

@NickLucche

🐛 Bug

Hey, I have found a consistent mismatch between the output of the flash_attention implementation and torch's F.scaled_dot_product_attention in the default non-causal case.
Is this still within the accuracy margin you were targeting?

Thanks for your help!

To Reproduce

import torch
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.custom_kernel

scale = 1.0
for _ in range(10):
    q = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())
    k = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())
    v = torch.randn(1, 16, 577, 64, dtype=torch.bfloat16).to(xm.xla_device())

    # Reference: PyTorch SDPA lowered through XLA
    output_torch = F.scaled_dot_product_attention(q, k, v, scale=scale)
    xm.mark_step()
    # Pallas flash attention kernel
    output = torch_xla.experimental.custom_kernel.flash_attention(q, k, v, sm_scale=scale)
    xm.mark_step()
    try:
        torch.testing.assert_close(output, output_torch)
    except AssertionError as e:
        print(e, '\n\n')

Outputs:

Tensor-likes are not close!

Mismatched elements: 579871 / 590848 (98.1%)
Greatest absolute difference: 0.0081862211227417 at index (0, 8, 453, 21) (up to 1e-05 allowed)
Greatest relative difference: 144.22958374023438 at index (0, 8, 379, 43) (up to 1.3e-06 allowed) 


Tensor-likes are not close!

Mismatched elements: 579790 / 590848 (98.1%)
Greatest absolute difference: 0.008110404014587402 at index (0, 4, 502, 13) (up to 1e-05 allowed)
Greatest relative difference: 850.0736083984375 at index (0, 3, 489, 60) (up to 1.3e-06 allowed) 


Tensor-likes are not close!

Mismatched elements: 579248 / 590848 (98.0%)
Greatest absolute difference: 0.008902788162231445 at index (0, 9, 145, 7) (up to 1e-05 allowed)
Greatest relative difference: 589.856201171875 at index (0, 15, 274, 8) (up to 1.3e-06 allowed) 


Tensor-likes are not close!

Mismatched elements: 579571 / 590848 (98.1%)
Greatest absolute difference: 0.010603189468383789 at index (0, 8, 3, 49) (up to 1e-05 allowed)
Greatest relative difference: 332.93280029296875 at index (0, 0, 367, 32) (up to 1.3e-06 allowed) 
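
A minimal diagnostic sketch (assuming it is placed inside the repro loop above, after the two mark_step() calls) that compares both bfloat16 outputs against a float32 SDPA reference, to gauge how much of the gap is plain bfloat16 rounding versus actual kernel divergence:

# Sketch: float32 reference computed from the same inputs upcast to float32.
reference = F.scaled_dot_product_attention(q.float(), k.float(), v.float(), scale=scale)
xm.mark_step()
err_sdpa = (output_torch.float() - reference).abs().max().item()
err_flash = (output.float() - reference).abs().max().item()
print(f"max |sdpa  - fp32 ref| = {err_sdpa:.6f}")
print(f"max |flash - fp32 ref| = {err_flash:.6f}")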

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: TPU v6
  • torch_xla version:
Name: torch
Version: 2.8.0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: compressed-tensors, outlines, xgrammar
---
Name: torch-xla
Version: 2.8.0+git4190fc0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: absl-py, numpy, pyyaml, requests
Required-by:
ysiraichi added the pallas, pytorch divergence (XLA behavior doesn't match PyTorch eager frontend), and xla:tpu (TPU specific issues and PRs) labels on Mar 24, 2025
@ysiraichi (Collaborator)

Thank you for filing the issue.
cc @vanbasten23 @bhavya01 @zpcore @qihqi

@qihqi (Collaborator) commented Mar 24, 2025

You can add these 2 lines before your test to make it a little more accurate:

  jax.config.update('jax_enable_x64', True)
  jax.config.update('jax_default_matmul_precision', 'highest')

The flash attention Pallas implementation comes from JAX directly.

Their unit tests use 2e-3 for accuracy: https://github.com/jax-ml/jax/blob/main/jax/experimental/mosaic/gpu/examples/flash_attention.py#L597
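
For a like-for-like check on the torch_xla side, a minimal sketch of rerunning the comparison with an explicit tolerance in that 2e-3 range (the exact atol/rtol values here are illustrative, not an official accuracy target):

# 2e-3 mirrors the tolerance used in the JAX flash-attention test; rtol is an assumption.
torch.testing.assert_close(output, output_torch, atol=2e-3, rtol=2e-3)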
