diff --git a/yunchang/ring/ring_flashinfer_attn.py b/yunchang/ring/ring_flashinfer_attn.py index e8f8bed..e38f2c2 100644 --- a/yunchang/ring/ring_flashinfer_attn.py +++ b/yunchang/ring/ring_flashinfer_attn.py @@ -1,12 +1,8 @@ import torch import torch.distributed as dist -# from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward from .utils import RingComm, update_out_and_lse from yunchang.kernels import select_flash_attn_impl, AttnType -import torch.utils.cpp_extension as torch_cpp_ext - -torch_cpp_ext._get_cuda_arch_flags() def ring_flashinfer_attn_forward(