diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 9a611aa9ef..1668441032 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -1263,6 +1263,7 @@ def _flash_attn_forward( cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, sink_ptr: Optional[Tensor] = None, + out: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, seqlen_q, nhead_q, hdim_q = q.shape @@ -1323,7 +1324,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]): _validate_cu("cu_seqlens_kv", cu_seqlens_kv) if can_impl_fmha_v3_fwd() and seqlen_q > 128: # Prefer CK for decode cases - out, softmax_lse, S_dmask, rng_state = fmha_v3_fwd( + out_, softmax_lse, S_dmask, rng_state = fmha_v3_fwd( q, k, v, @@ -1335,7 +1336,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]): return_lse, return_softmax, how_v3_bf16_cvt, - None, + out, bias, alibi_slopes, q_descale, @@ -1344,7 +1345,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]): None, ) else: - out, softmax_lse, S_dmask, rng_state = mha_fwd( + out_, softmax_lse, S_dmask, rng_state = mha_fwd( q, k, v, @@ -1358,7 +1359,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]): return_softmax, cu_seqlens_q, cu_seqlens_kv, - None, + out, bias, alibi_slopes, q_descale, @@ -1368,7 +1369,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]): None, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) - return out, softmax_lse, S_dmask, rng_state + return out_, softmax_lse, S_dmask, rng_state # @torch_compile_guard(mutates_args=[])