diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 956c5efd7f8d..52ec91e30183 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -11,7 +11,8 @@
 import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.flags import enabled_flags
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -133,7 +134,9 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+        HPUFusedSDPA = kernels.fsdpa()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
 
         self.prefill_impl = 'naive'
         if "flex_attention" in enabled_flags():
@@ -272,7 +275,7 @@ def common_attention_args(self,
                               block_list=None,
                               key_cache=None,
                               value_cache=None):
-        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
+        fsdpa_op = self.fused_scaled_dot_product_attention \
             if self.fused_scaled_dot_product_attention is not None else None
         return {
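
The change above guards against `kernels.fsdpa()` returning no kernel: the raw kernel is wrapped in `ModuleFusedSDPA` only when it exists, and `common_attention_args` now passes the wrapper itself (or `None`) instead of its `.apply` attribute. Below is a minimal, self-contained sketch of this optional-kernel pattern; the names `load_fused_kernel`, `FusedSDPAModule`, and `AttentionImpl` are hypothetical stand-ins, not the `vllm_hpu_extension` API.

```python
from typing import Callable, Optional


def load_fused_kernel() -> Optional[Callable]:
    """Pretend kernel loader; returns None when the fused op is unavailable."""
    return None  # e.g. running on hardware without a fused SDPA kernel


class FusedSDPAModule:
    """Thin wrapper that makes the raw kernel callable like a module."""

    def __init__(self, kernel: Callable):
        self._kernel = kernel

    def __call__(self, *args, **kwargs):
        return self._kernel(*args, **kwargs)


class AttentionImpl:

    def __init__(self):
        kernel = load_fused_kernel()
        # Wrap only when the kernel exists; otherwise keep None so callers
        # can detect the missing op and fall back to a naive implementation.
        self.fused_sdpa = None if kernel is None else FusedSDPAModule(kernel)

    def attention_args(self) -> dict:
        # Pass the wrapper itself (not an `.apply` attribute), or None.
        return {"fsdpa_op": self.fused_sdpa}
```

The point of the sketch is the two-sided contract: construction decides availability once, and every consumer only has to check the single attribute for `None` rather than probing the kernel module directly.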