@@ -15,9 +15,9 @@
 #include "helper.h"
 #include "paddle/extension.h"
 #ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/core/memory/memcpy.h"
 #endif
-#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
 #include "utils.cuh"

 template <int THREADBLOCK_SIZE>
@@ -290,10 +290,11 @@ void GetBlockShapeAndSplitKVBlock(
       bsz);
   // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data
   // is only for branching in attention.
-  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if (!phi::backends::gpu::IsCUDAGraphCapturing())
+#endif
     max_len_tensor_cpu.copy_(
         max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
-  }

   auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
   int max_len_this_time = max_len_cpu_ptr[0];
@@ -404,10 +405,11 @@ void GetBlockShapeAndSplitKVBlock(
       group_size);
   // Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU
   // data is only for branching in attention.
-  if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
+#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if (!phi::backends::gpu::IsCUDAGraphCapturing())
+#endif
     decoder_num_blocks_cpu.copy_(
         decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
-  }
   PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
       decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
 }
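For reference, a minimal standalone sketch of the guard pattern both hunks apply: when PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU is defined, the braceless `if` is compiled out and the device-to-host copy runs unconditionally, while on CUDA builds the copy is still skipped during CUDA graph capture. The helper names below (IsCUDAGraphCapturingStub, CopyMaxLenToHost) are illustrative stand-ins, not the real Paddle/FastDeploy APIs.

// Illustrative sketch only; stand-in helpers, not the real phi/Paddle calls.
#include <cstdio>

// Stand-in for phi::backends::gpu::IsCUDAGraphCapturing() (assumption).
static bool IsCUDAGraphCapturingStub() { return false; }

// Stand-in for max_len_tensor_cpu.copy_(max_len_tensor_gpu, ...) (assumption).
static void CopyMaxLenToHost() { std::printf("DtoH copy executed\n"); }

void SyncMaxLenToHost() {
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
  // CUDA build: the copy below is skipped while a CUDA graph is being captured.
  if (!IsCUDAGraphCapturingStub())
#endif
    // METAX build: the `if` above is compiled out, so this runs unconditionally.
    CopyMaxLenToHost();
}

int main() { SyncMaxLenToHost(); }

This also shows why the closing `}` is dropped in the diff: with a braceless `if`, the same copy statement serves both builds without being duplicated.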
|