From d340403a5108aa7d186d799b6a5c28dbcf937394 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Thu, 29 Jan 2026 03:24:40 +0530 Subject: [PATCH 1/2] fix to use correct attn backend Signed-off-by: Divyansh Singhvi --- vllm_omni/diffusion/attention/layer.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 0d3bf8b7da9..a57213ee8ac 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -12,6 +12,8 @@ from vllm.logger import init_logger from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata +from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl +from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention from vllm_omni.diffusion.attention.selector import get_attn_backend @@ -45,6 +47,14 @@ def __init__( causal=causal, num_kv_heads=num_kv_heads, ) + # Instantiate fallback backend for float32 support + self.sdpa_fallback = SDPABackend.get_impl_cls()( + num_heads=num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + causal=causal, + num_kv_heads=num_kv_heads, + ) self.backend_pref = None self.softmax_scale = softmax_scale @@ -104,12 +114,12 @@ def forward( return out def _run_local_attention(self, query, key, value, attn_metadata): - if self.backend_pref == "flash_attn" and query.dtype == torch.float32: - logger.warning( + if isinstance(self.attention, FlashAttentionImpl) and query.dtype == torch.float32: + logger.warning_once( "Flash Attention does not support float32. Overriding user config " f"attention_backend='{self.backend_pref}' to 'sdpa' for dtype={query.dtype}." ) - self.backend_pref = "sdpa" + return self.sdpa_fallback.forward(query, key, value, attn_metadata) # Fallback to standard attention return self.attention.forward(query, key, value, attn_metadata) From bce75fba800d21165b180dc47373036775c304d6 Mon Sep 17 00:00:00 2001 From: Divyansh Singhvi Date: Thu, 29 Jan 2026 10:46:29 +0530 Subject: [PATCH 2/2] fix to always revert to SDPA backend on query type float32 Signed-off-by: Divyansh Singhvi --- vllm_omni/diffusion/attention/layer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index a57213ee8ac..83cfe84d7b6 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -12,7 +12,6 @@ from vllm.logger import init_logger from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata -from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention @@ -114,9 +113,9 @@ def forward( return out def _run_local_attention(self, query, key, value, attn_metadata): - if isinstance(self.attention, FlashAttentionImpl) and query.dtype == torch.float32: + if query.dtype == torch.float32: logger.warning_once( - "Flash Attention does not support float32. Overriding user config " + f"Only SDPA supports float32. Overriding user config {type(self.attention)} " f"attention_backend='{self.backend_pref}' to 'sdpa' for dtype={query.dtype}." ) return self.sdpa_fallback.forward(query, key, value, attn_metadata)