Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion yunchang/ring/ring_flashinfer_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
# from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
from .utils import RingComm, update_out_and_lse
from yunchang.kernels import select_flash_attn_impl, AttnType
from yunchang.globals import HAS_FLASHINFER
import torch.utils.cpp_extension as torch_cpp_ext

torch_cpp_ext._get_cuda_arch_flags()
if HAS_FLASHINFER:
torch_cpp_ext._get_cuda_arch_flags()


def ring_flashinfer_attn_forward(
Expand Down