Commit 5b24013

skip DtoH capture (#4988)
1 parent 8329338 commit 5b24013

File tree

1 file changed: +13 −6 lines changed


custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 13 additions & 6 deletions
@@ -17,6 +17,7 @@
 #ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
 #include "paddle/phi/core/memory/memcpy.h"
 #endif
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "utils.cuh"
 
 template <int THREADBLOCK_SIZE>
@@ -287,9 +288,12 @@ void GetBlockShapeAndSplitKVBlock(
       seq_lens_encoder.data<int>(),
       max_len_tensor_gpu.data<int>(),
       bsz);
-
-  max_len_tensor_cpu.copy_(
-      max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
+  // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data
+  // is only for branching in attention.
+  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+    max_len_tensor_cpu.copy_(
+        max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
+  }
 
   auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
   int max_len_this_time = max_len_cpu_ptr[0];
@@ -398,9 +402,12 @@ void GetBlockShapeAndSplitKVBlock(
       bsz,
       decoder_block_shape_q,
       group_size);
-
-  decoder_num_blocks_cpu.copy_(
-      decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
+  // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU
+  // data is only for branching in attention.
+  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+    decoder_num_blocks_cpu.copy_(
+        decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
+  }
   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
       decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
 }
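The guarded copies in this diff follow a general CUDA-graph pattern: a device-to-host transfer is skipped while a graph is being captured, and the previously copied host value is reused, which is acceptable when that value only drives coarse branching (as the commit note explains for the attention path). The standalone sketch below is not FastDeploy code; it uses the raw CUDA runtime query cudaStreamIsCapturing instead of the repository's phi::backends::gpu::IsCUDAGraphCapturing helper, and the function and variable names (maybe_copy_to_host, d_val, h_val) are illustrative.

    // Minimal sketch, assuming the host value is only read for branching and a
    // slightly stale value is tolerable during graph capture.
    #include <cuda_runtime.h>

    // Copy a device-side int to host memory, but skip the DtoH transfer while
    // the stream is in CUDA graph capture mode so the copy is not recorded
    // into (and replayed by) the graph.
    void maybe_copy_to_host(const int* d_val, int* h_val, cudaStream_t stream) {
      cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
      cudaStreamIsCapturing(stream, &status);
      if (status == cudaStreamCaptureStatusNone) {
        // Not capturing: do the copy and wait so the host value is usable now.
        cudaMemcpyAsync(h_val, d_val, sizeof(int), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);
      }
      // Capturing: keep the previously copied host value.
    }

The design choice mirrors the diff: the copy (and any implied synchronization) is expensive to capture and replay, so it is performed only outside capture, while GPU-side consumers keep reading the device tensor directly.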
