From 2ff34689db3eb905f7b87fe2390bac512d134e53 Mon Sep 17 00:00:00 2001
From: 0x45f
Date: Tue, 6 Dec 2022 11:20:37 +0000
Subject: [PATCH] Fix accuracy fp16 kernel return fp32 tensor error

---
 paddle/phi/kernels/gpu/accuracy_kernel.cu | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/paddle/phi/kernels/gpu/accuracy_kernel.cu b/paddle/phi/kernels/gpu/accuracy_kernel.cu
index ef3e5b9af2408..8a4aa2a6397c9 100644
--- a/paddle/phi/kernels/gpu/accuracy_kernel.cu
+++ b/paddle/phi/kernels/gpu/accuracy_kernel.cu
@@ -26,13 +26,13 @@ namespace phi {
 using phi::PADDLE_CUDA_NUM_THREADS;
 
-template <int BlockSize>
+template <typename T, int BlockSize>
 __global__ void AccuracyCudaKernel(const int N,
                                    const int D,
                                    const int64_t* Xdata,
                                    const int64_t* labeldata,
                                    int* correct_data,
-                                   float* accuracy,
+                                   T* accuracy,
                                    int* total_data) {
   int count = 0;
   __shared__ int total[BlockSize];
@@ -64,7 +64,7 @@ __global__ void AccuracyCudaKernel(const int N,
 #endif
   if (threadIdx.x == 0) {
     *correct_data = result;
-    *accuracy = static_cast<float>(result) / static_cast<float>(N);
+    *accuracy = static_cast<T>(result) / static_cast<T>(N);
     *total_data = N;
   }
 }
@@ -84,18 +84,18 @@ void AccuracyRawKernel(const Context& dev_ctx,
   int* correct_data = dev_ctx.template Alloc<int>(correct);
   int* total_data = dev_ctx.template Alloc<int>(total);
-  float* accuracy_data = dev_ctx.template Alloc<float>(accuracy);
+  T* accuracy_data = dev_ctx.template Alloc<T>(accuracy);
 
   int num_samples = static_cast<int>(inference.dims()[0]);
   size_t infer_width = inference.dims()[1];
   auto stream = dev_ctx.stream();
-  phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);
+  phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(T), stream);
 
   if (num_samples == 0) {
     return;
   }
 
-  AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS>
+  AccuracyCudaKernel<T, PADDLE_CUDA_NUM_THREADS>
       <<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(num_samples,
                                                   infer_width,
                                                   indices_data,