Skip to content

Commit 249feca

Browse files
authored
[BugFix] Revert skip capture (#5023)
* Revert "[BugFix][Metax] Fix metax compile issue in get_block_shape_and_split_kv_block (#5000)". This reverts commit 05da8e3.
* Revert "skip DtoH capture (#4988)". This reverts commit 5b24013.
1 parent 51b1f13 commit 249feca

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "helper.h"
1616
#include "paddle/extension.h"
1717
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
18-
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
1918
#include "paddle/phi/core/memory/memcpy.h"
2019
#endif
2120
#include "utils.cuh"
@@ -288,13 +287,9 @@ void GetBlockShapeAndSplitKVBlock(
288287
seq_lens_encoder.data<int>(),
289288
max_len_tensor_gpu.data<int>(),
290289
bsz);
291-
// Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU data
292-
// is only for branching in attention.
293-
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
294-
if (!phi::backends::gpu::IsCUDAGraphCapturing())
295-
#endif
296-
max_len_tensor_cpu.copy_(
297-
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
290+
291+
max_len_tensor_cpu.copy_(
292+
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
298293

299294
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
300295
int max_len_this_time = max_len_cpu_ptr[0];
@@ -403,13 +398,9 @@ void GetBlockShapeAndSplitKVBlock(
403398
bsz,
404399
decoder_block_shape_q,
405400
group_size);
406-
// Note (sunxin): Skip capturing the DtoH copy (it's time-consuming); CPU
407-
// data is only for branching in attention.
408-
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
409-
if (!phi::backends::gpu::IsCUDAGraphCapturing())
410-
#endif
411-
decoder_num_blocks_cpu.copy_(
412-
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
401+
402+
decoder_num_blocks_cpu.copy_(
403+
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
413404
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
414405
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
415406
}

0 commit comments

Comments (0)