torch_xla.experimental.custom_kernel.flash_attention output does not match F.scaled_dot_product_attention on TPU #8869
Labels: pallas, pytorch divergence (XLA behavior doesn't match PyTorch eager frontend), xla:tpu (TPU-specific issues and PRs)
🐛 Bug
Hey, I have found a consistent mismatch between the output of the flash_attention implementation and torch's F.scaled_dot_product_attention for the default non-causal case. Is this still within the accuracy margin you were targeting? Thanks for your help!
To Reproduce
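The original reproduction snippet is not shown; below is a minimal sketch of the comparison described above. The tensor shapes, dtype, and tolerances are assumptions, and the call uses only the default (non-causal) arguments of flash_attention.

```python
# Minimal sketch (assumed shapes/dtype, not the reporter's exact script).
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import flash_attention

device = xm.xla_device()
torch.manual_seed(0)

# Hypothetical shapes: (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
k = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)
v = torch.randn(1, 8, 1024, 128, dtype=torch.bfloat16, device=device)

# Pallas flash attention kernel, default non-causal case.
out_flash = flash_attention(q, k, v)

# Reference attention computed via the standard PyTorch op.
out_ref = F.scaled_dot_product_attention(q, k, v)

xm.mark_step()
diff = (out_flash - out_ref).abs().max()
print("max abs diff:", diff.item())
print("allclose:", torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-2))
```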
Outputs:
Environment