diff --git a/vllm/attention/layers/cross_attention.py b/vllm/attention/layers/cross_attention.py index cfd203bdd37b..f58c9d541775 100644 --- a/vllm/attention/layers/cross_attention.py +++ b/vllm/attention/layers/cross_attention.py @@ -149,16 +149,20 @@ def __init__( kv_cache_dtype = "auto" block_size = 16 - underlying_attn_backend = get_attn_backend( - head_size, dtype, kv_cache_dtype, block_size - ) - attn_backend = create_cross_attention_backend(underlying_attn_backend) - if attn_type is not None: assert attn_type == AttentionType.ENCODER_DECODER, ( "CrossAttention only supports AttentionType.ENCODER_DECODER" ) + underlying_attn_backend = get_attn_backend( + head_size, + dtype, + kv_cache_dtype, + block_size, + attn_type=AttentionType.ENCODER_DECODER, + ) + attn_backend = create_cross_attention_backend(underlying_attn_backend) + super().__init__( num_heads=num_heads, head_size=head_size,