Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
#include "paddle/phi/core/memory/memcpy.h"
#endif
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "utils.cuh"

template <int THREADBLOCK_SIZE>
Expand Down Expand Up @@ -287,9 +288,12 @@ void GetBlockShapeAndSplitKVBlock(
seq_lens_encoder.data<int>(),
max_len_tensor_gpu.data<int>(),
bsz);

max_len_tensor_cpu.copy_(
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
// Note (sunxin): Skip the DtoH copy while a CUDA Graph is being captured (the
// copy is time-consuming); the CPU data is only used for branching in
// attention.
if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
max_len_tensor_cpu.copy_(
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
}
Comment on lines +291 to +296
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个工作会生效吗

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> 这个工作会生效吗

before:

image

after:

image


auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
Expand Down Expand Up @@ -398,9 +402,12 @@ void GetBlockShapeAndSplitKVBlock(
bsz,
decoder_block_shape_q,
group_size);

decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
// Note (sunxin): Skip the DtoH copy while a CUDA Graph is being captured (the
// copy is time-consuming); the CPU data is only used for branching in
// attention.
if (!phi::backends::gpu::IsCUDAGraphCapturing()) {
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
Expand Down
Loading