diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 7ab4352984e..43f3387d1df 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1535,10 +1535,11 @@ std::tuple mha_bwd( // Will be zero'ed out in the backward preprocess kernel at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32)); params.dq_semaphore = dq_semaphore.data_ptr(); + at::Tensor dk_semaphore, dv_semaphore; if (num_heads_k != num_heads && params.deterministic) { // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel - at::Tensor dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); } diff --git a/hopper/flash_api_stable.cpp b/hopper/flash_api_stable.cpp index 5ae58bdd129..9759af86e08 100644 --- a/hopper/flash_api_stable.cpp +++ b/hopper/flash_api_stable.cpp @@ -1610,10 +1610,11 @@ std::tuple mha_bwd( // Will be zero'ed out in the backward preprocess kernel Tensor dq_semaphore = torch::stable::new_empty(q, {(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, std::make_optional(torch::headeronly::ScalarType::Int)); params.dq_semaphore = static_cast(dq_semaphore.data_ptr()); + Tensor dk_semaphore, dv_semaphore; if (num_heads_k != num_heads && params.deterministic) { // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel - Tensor dk_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); - Tensor dv_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + dk_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); + dv_semaphore = torch::stable::new_zeros(q, {(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, std::make_optional(torch::headeronly::ScalarType::Int)); params.dk_semaphore = static_cast(dk_semaphore.data_ptr()); params.dv_semaphore = static_cast(dv_semaphore.data_ptr()); }