diff --git a/vllm_omni/diffusion/attention/backends/cudnn_attn.py b/vllm_omni/diffusion/attention/backends/cudnn_attn.py index f27fe18706f..357a2c0fd06 100644 --- a/vllm_omni/diffusion/attention/backends/cudnn_attn.py +++ b/vllm_omni/diffusion/attention/backends/cudnn_attn.py @@ -51,6 +51,7 @@ def __init__( ) -> None: self.causal = causal self.softmax_scale = softmax_scale + self.requires_gqa = num_heads != num_kv_heads def forward_cuda( self, @@ -84,6 +85,7 @@ def forward_cuda( dropout_p=0.0, is_causal=self.causal, scale=self.softmax_scale, + enable_gqa=self.requires_gqa, ) except RuntimeError as e: if "No available kernel" not in str(e): @@ -100,5 +102,6 @@ def forward_cuda( dropout_p=0.0, is_causal=self.causal, scale=self.softmax_scale, + enable_gqa=self.requires_gqa, ) return output.permute(0, 2, 1, 3)