Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions vllm_omni/diffusion/attention/layer.py
Original file line number Diff line number Diff line change
@@ -12,6 +12,7 @@
 from vllm.logger import init_logger

 from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
+from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend
 from vllm_omni.diffusion.attention.parallel import build_parallel_attention_strategy
 from vllm_omni.diffusion.attention.parallel.ring import RingParallelAttention
 from vllm_omni.diffusion.attention.selector import get_attn_backend
@@ -45,6 +46,14 @@ def __init__(
             causal=causal,
             num_kv_heads=num_kv_heads,
         )
+        # Instantiate fallback backend for float32 support
+        self.sdpa_fallback = SDPABackend.get_impl_cls()(
+            num_heads=num_heads,
+            head_size=head_size,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            num_kv_heads=num_kv_heads,
+        )
         self.backend_pref = None

         self.softmax_scale = softmax_scale
@@ -104,12 +113,12 @@ def forward(
         return out

     def _run_local_attention(self, query, key, value, attn_metadata):
-        if self.backend_pref == "flash_attn" and query.dtype == torch.float32:
-            logger.warning(
-                "Flash Attention does not support float32. Overriding user config "
+        if query.dtype == torch.float32:
+            logger.warning_once(
+                f"Only SDPA supports float32. Overriding user config {type(self.attention)} "
                 f"attention_backend='{self.backend_pref}' to 'sdpa' for dtype={query.dtype}."
             )
-            self.backend_pref = "sdpa"
+            return self.sdpa_fallback.forward(query, key, value, attn_metadata)

         # Fallback to standard attention
         return self.attention.forward(query, key, value, attn_metadata)
Expand Down