@@ -150,20 +150,35 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
         if model_config is not None and model_config.use_mla:
-            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
-            # we default to FlashMLA backend, so we need to force the blocksize
-            # here
-            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
-                or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
+            # then we default to the FlashMLA backend on non-Blackwell GPUs
+            # and to CutlassMLA on Blackwell. In each case, we force the
+            # required block_size.
+            use_flashmla = False
+            use_cutlass_mla = False
+
+            if envs.VLLM_ATTENTION_BACKEND is None:
+                # Default case
+                if cls.is_device_capability(100):
+                    # Blackwell => force CutlassMLA.
+                    use_cutlass_mla = True
+                    envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA_VLLM_V1"
+                else:
+                    # Not Blackwell
+                    use_flashmla = True
+            else:
+                # Forced case
+                use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+                use_cutlass_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
+
             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
                     and cache_config.block_size != 64:
                 cache_config.block_size = 64
                 logger.info(
                     "Forcing kv cache block size to 64 for FlashMLA backend.")
 
-            use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND is not None \
-                and envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
             if use_cutlass_mla and cache_config.block_size != 128:
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
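For reference, here is a minimal standalone sketch of the decision table this hunk implements. The function and parameter names (`pick_mla_backend`, `backend_env`, `is_blackwell`) are illustrative rather than vLLM API, and the sketch omits the `is_flashmla_supported()` gate that the real code checks before forcing FlashMLA's block size:

```python
from typing import Optional, Tuple


# Hypothetical helper, not part of vLLM: mirrors the backend / block-size
# decision that check_and_update_config() makes for MLA models.
def pick_mla_backend(backend_env: Optional[str],
                     is_blackwell: bool) -> Tuple[str, Optional[int]]:
    """Return (backend_name, forced_block_size or None)."""
    if backend_env is None:
        # Default case: Blackwell (device capability 100) gets CutlassMLA,
        # every other GPU gets FlashMLA.
        backend_env = "CUTLASS_MLA_VLLM_V1" if is_blackwell else "FLASHMLA"
    if backend_env == "CUTLASS_MLA_VLLM_V1":
        return backend_env, 128  # CutlassMLA requires a block size of 128
    if backend_env == "FLASHMLA":
        return backend_env, 64   # FlashMLA requires a block size of 64
    return backend_env, None     # other backends leave block_size untouched


# Default selection on Blackwell vs. older GPUs, and an explicit override.
assert pick_mla_backend(None, is_blackwell=True) == ("CUTLASS_MLA_VLLM_V1", 128)
assert pick_mla_backend(None, is_blackwell=False) == ("FLASHMLA", 64)
assert pick_mla_backend("FLASHMLA", is_blackwell=True) == ("FLASHMLA", 64)
```

The practical effect of the change: with `VLLM_ATTENTION_BACKEND` unset, Blackwell GPUs now default to CUTLASS_MLA_VLLM_V1 with 128-token KV-cache blocks, while explicitly setting `VLLM_ATTENTION_BACKEND=FLASHMLA` still forces 64-token blocks (provided FlashMLA is supported).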