diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 0d16a843586e..f7191799065f 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -11,7 +11,8 @@
 import vllm_hpu_extension.kernels as kernels
 import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.flags import enabled_flags
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -373,7 +374,9 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
-        self.fused_scaled_dot_product_attention = kernels.fsdpa()
+        HPUFusedSDPA = kernels.fsdpa()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
 
         self.prefill_impl = 'naive'
         if "flex_attention" in enabled_flags():
@@ -506,16 +509,13 @@ def common_attention_args(self,
                               block_list=None,
                               key_cache=None,
                               value_cache=None):
-        fsdpa_op = self.fused_scaled_dot_product_attention.apply \
-            if self.fused_scaled_dot_product_attention is not None else None
-
         return {
             'scale': self.scale,
             'matmul_qk_op': self.matmul_qk,
             'matmul_av_op': self.matmul_av,
             'batch2block_matmul_op': self.batch2block_matmul,
             'block2batch_matmul_op': self.block2batch_matmul,
-            'fsdpa_op': fsdpa_op,
+            'fsdpa_op': self.fused_scaled_dot_product_attention,
             'keys_fetch_func': self.k_cache.fetch_from_cache,
             'values_fetch_func': self.v_cache.fetch_from_cache,
             'softmax_op': self.softmax,
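
For context, the diff moves the None check from common_attention_args into __init__: the raw handle returned by kernels.fsdpa() is wrapped once in ModuleFusedSDPA, and the wrapper (or None) is then passed straight through as 'fsdpa_op' instead of unwrapping .apply at every call. The sketch below is a minimal illustration of that guard-and-wrap pattern using hypothetical stand-in classes; the real ModuleFusedSDPA comes from vllm_hpu_extension.utils and its interface is assumed here, not shown in this diff.

    # Minimal sketch of the guard-and-wrap pattern, with stand-in classes only.
    from typing import Any, Optional


    class FakeFusedSDPAKernel:
        """Stand-in for the object returned by kernels.fsdpa() (may be None)."""

        @staticmethod
        def apply(*args: Any, **kwargs: Any) -> Any:
            raise NotImplementedError("placeholder for the HPU fused SDPA kernel")


    class FakeModuleFusedSDPA:
        """Stand-in wrapper that makes the raw kernel directly callable."""

        def __init__(self, fsdpa_kernel: Any) -> None:
            self._kernel = fsdpa_kernel

        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            return self._kernel.apply(*args, **kwargs)


    def build_fsdpa_op(kernel: Optional[Any]) -> Optional[FakeModuleFusedSDPA]:
        # Wrap once at construction time; callers can then pass the result
        # straight through as 'fsdpa_op' without repeating the None check or
        # reaching into .apply at every call site.
        return None if kernel is None else FakeModuleFusedSDPA(kernel)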