diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 10036aeb9ca7..26d537e966cd 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -1087,6 +1087,7 @@ def flash_attention_non_xla(q: torch.Tensor, ) + @impl(XLA_LIB, "paged_attention", "XLA") def paged_attention_xla(q: torch.Tensor, k_pages: torch.Tensor,