diff --git a/vllm/model_executor/layers/kda.py b/vllm/model_executor/layers/kda.py index 308bc8be1dec..26458f2e3c4d 100644 --- a/vllm/model_executor/layers/kda.py +++ b/vllm/model_executor/layers/kda.py @@ -259,7 +259,7 @@ def forward( hidden_states: torch.Tensor, positions: torch.Tensor, output: torch.Tensor, - ) -> torch.Tensor: + ) -> None: num_tokens = hidden_states.size(0) q = self.q_proj(hidden_states)[0] k = self.k_proj(hidden_states)[0] @@ -291,8 +291,7 @@ def forward( ) core_attn_out = self.o_norm(core_attn_out, g2) core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)") - - return self.o_proj(core_attn_out)[0] + output[:] = self.o_proj(core_attn_out)[0] def _forward( self,