diff --git a/yunchang/ring/ring_flashinfer_attn.py b/yunchang/ring/ring_flashinfer_attn.py index e8f8bed..b1c224a 100644 --- a/yunchang/ring/ring_flashinfer_attn.py +++ b/yunchang/ring/ring_flashinfer_attn.py @@ -4,9 +4,11 @@ # from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward from .utils import RingComm, update_out_and_lse from yunchang.kernels import select_flash_attn_impl, AttnType +from yunchang.globals import HAS_FLASHINFER import torch.utils.cpp_extension as torch_cpp_ext -torch_cpp_ext._get_cuda_arch_flags() +if HAS_FLASHINFER: + torch_cpp_ext._get_cuda_arch_flags() def ring_flashinfer_attn_forward(