Skip to content

Commit 05da8e3

Browse files
authored
[BugFix][Metax] Fix metax compile issue in get_block_shape_and_split_kv_block (#5000)
* fix metax compile
* fix
1 parent 88da9d9 commit 05da8e3

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
#include "helper.h"
1616
#include "paddle/extension.h"
1717
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
18+
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
1819
#include "paddle/phi/core/memory/memcpy.h"
1920
#endif
20-
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
2121
#include "utils.cuh"
2222

2323
template <int THREADBLOCK_SIZE>
@@ -290,10 +290,11 @@ void GetBlockShapeAndSplitKVBlock(
290290
bsz);
291291
// Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data
292292
// is only for branching in attention.
293-
if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
293+
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
294+
if (!phi::backends::gpu::IsCUDAGraphCapturing())
295+
#endif
294296
max_len_tensor_cpu.copy_(
295297
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
296-
}
297298

298299
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
299300
int max_len_this_time = max_len_cpu_ptr[0];
@@ -404,10 +405,11 @@ void GetBlockShapeAndSplitKVBlock(
404405
group_size);
405406
// Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU
406407
// data is only for branching in attention.
407-
if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
408+
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
409+
if (!phi::backends::gpu::IsCUDAGraphCapturing())
410+
#endif
408411
decoder_num_blocks_cpu.copy_(
409412
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
410-
}
411413
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
412414
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
413415
}

0 commit comments

Comments
 (0)