diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 8ffd0d0baf9..adb53fdab6b 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -1529,9 +1529,9 @@ std::tuple(); if (num_heads_k != num_heads && params.deterministic) { - // TODO: do we need to zero them out? - at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); - at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + // TODO: maybe also zero'ed out dk_semaphore and dv_semaphore in the backward preprocess kernel + at::Tensor dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); + at::Tensor dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32)); params.dk_semaphore = dk_semaphore.data_ptr(); params.dv_semaphore = dv_semaphore.data_ptr(); }