diff --git a/setup.py b/setup.py index 866bbffc5a..bef32c36e2 100644 --- a/setup.py +++ b/setup.py @@ -176,7 +176,9 @@ def get_flash_attention2_nvcc_archs_flags(cuda_version: int): return [] # Figure out default archs to target DEFAULT_ARCHS_LIST = "" - if cuda_version >= 1108: + if cuda_version >= 1208: + DEFAULT_ARCHS_LIST = "8.0;8.6;9.0;10.0;12.0" + elif cuda_version >= 1108: DEFAULT_ARCHS_LIST = "8.0;8.6;9.0" elif cuda_version > 1100: DEFAULT_ARCHS_LIST = "8.0;8.6"