diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 0d3bf8b7da9..83cfe84d7b6 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention from vllm_omni.diffusion.attention.selector import get_attn_backend @@ -45,6 +46,14 @@ def __init__( causal=causal, num_kv_heads=num_kv_heads, ) + # Instantiate fallback backend for float32 support + self.sdpa_fallback = SDPABackend.get_impl_cls()( + num_heads=num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + causal=causal, + num_kv_heads=num_kv_heads, + ) self.backend_pref = None self.softmax_scale = softmax_scale @@ -104,12 +113,12 @@ def forward( return out def _run_local_attention(self, query, key, value, attn_metadata): - if self.backend_pref == "flash_attn" and query.dtype == torch.float32: - logger.warning( - "Flash Attention does not support float32. Overriding user config " + if query.dtype == torch.float32: + logger.warning_once( + f"Only SDPA supports float32. Overriding user config {type(self.attention)} " f"attention_backend='{self.backend_pref}' to 'sdpa' for dtype={query.dtype}." ) - self.backend_pref = "sdpa" + return self.sdpa_fallback.forward(query, key, value, attn_metadata) # Fallback to standard attention return self.attention.forward(query, key, value, attn_metadata)