diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index e849d0df002..5a443ab43db 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -144,10 +144,6 @@ def flash_attn_with_kvcache(
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the
             softmax normalization factor).
     """
-    if not is_fa3_supported():
-        raise NotImplementedError(
-            "flash_attn at sgl-kernel is only supported on sm90 and cu123 above"
-        )
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
     if softmax_scale is None:
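With the hunk above, `flash_attn_with_kvcache` no longer raises `NotImplementedError` on unsupported hardware; callers that still want that guard can apply it themselves. A minimal sketch, assuming `is_fa3_supported` remains defined in and importable from `sgl_kernel.flash_attn` and callable with no arguments (the wrapper below is hypothetical, not part of this change):

```python
from sgl_kernel.flash_attn import flash_attn_with_kvcache, is_fa3_supported


def guarded_flash_attn_with_kvcache(q, k_cache, v_cache, **kwargs):
    # Re-apply the check removed from the kernel wrapper at the call site.
    # Assumption: is_fa3_supported() is still exported and takes no required args.
    if not is_fa3_supported():
        raise NotImplementedError(
            "flash_attn at sgl-kernel is only supported on sm90 and cu123 above"
        )
    return flash_attn_with_kvcache(q, k_cache, v_cache, **kwargs)
```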