@@ -17,6 +17,7 @@
 #ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
 #include "paddle/phi/core/memory/memcpy.h"
 #endif
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "utils.cuh"
 
 template <int THREADBLOCK_SIZE>
@@ -287,9 +288,12 @@ void GetBlockShapeAndSplitKVBlock( |
       seq_lens_encoder.data<int>(),
       max_len_tensor_gpu.data<int>(),
       bsz);
-
-  max_len_tensor_cpu.copy_(
-      max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
+  // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data
+  // is only for branching in attention.
+  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+    max_len_tensor_cpu.copy_(
+        max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
+  }
 
   auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
   int max_len_this_time = max_len_cpu_ptr[0];
@@ -398,9 +402,12 @@ void GetBlockShapeAndSplitKVBlock( |
       bsz,
       decoder_block_shape_q,
       group_size);
-
-  decoder_num_blocks_cpu.copy_(
-      decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
+  // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU
+  // data is only for branching in attention.
+  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+    decoder_num_blocks_cpu.copy_(
+        decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
+  }
   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
       decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
 }
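
For reference, below is a minimal standalone sketch of the same pattern (skip device-to-host copies while a CUDA graph is being captured), written against the raw CUDA runtime API via cudaStreamIsCapturing rather than Paddle's phi::backends::gpu::IsCUDAGraphCapturing() helper that the new cuda_graph_with_memory_pool.h include provides. The function and buffer names are illustrative, not part of this patch.

#include <cuda_runtime.h>

// Sketch only: copy device results back to the host unless the stream is
// currently capturing a CUDA graph. This mirrors the guard added around
// max_len_tensor_cpu.copy_ and decoder_num_blocks_cpu.copy_ above.
void CopyToHostUnlessCapturing(const int* dev_buf,
                               int* host_buf,  // ideally pinned host memory
                               size_t num_ints,
                               cudaStream_t stream) {
  cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
  cudaStreamIsCapturing(stream, &status);  // query this stream's capture state
  if (status == cudaStreamCaptureStatusNone) {
    // Outside capture: issue the DtoH copy and wait so the host can read
    // host_buf immediately afterwards.
    cudaMemcpyAsync(host_buf, dev_buf, num_ints * sizeof(int),
                    cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
  }
  // While capturing, the copy is skipped, matching the note in the patch:
  // the CPU-side values are only used for branching in attention.
}

In the patch itself the check goes through phi::backends::gpu::IsCUDAGraphCapturing(), which plays the same role at the framework level; the raw-API version above is only meant to make the control flow explicit.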