diff --git a/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h b/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h
index b305ec96a30..9aedba871c5 100644
--- a/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h
+++ b/backends/metax_gpu/kernels/impl/metax_weight_quantize_kernel_impl.h
@@ -25,7 +25,7 @@
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
-
+template <typename DataType>
 void show_2d_cpu_tensor(const DenseTensor& tensor,
                         const int64_t row_num = 3,
                         const int64_t col_num = 3) {
@@ -33,18 +33,18 @@ void show_2d_cpu_tensor(const DenseTensor& tensor,
   const int64_t cols = tensor.dims()[1];
 
   printf("\nTensor shape = [%d, %d]\n", rows, cols);
-  const int8_t* cpu_ptr = tensor.data<int8_t>();
+  const DataType* cpu_ptr = tensor.data<DataType>();
   for (int r = 0; r < row_num; r++) {
     for (int c = 0; c < col_num; c++) {
-      int8_t val = *(cpu_ptr + r * cols + c);
-      printf("%d ", val);
+      DataType val = *(cpu_ptr + r * cols + c);
+      printf("%#x ", val);
     }
     printf("\n");
   }
   printf("\n\n");
 }
 
-
+template <typename DataType>
 void show_2d_gpu_tensor(const CustomContext& dev_ctx,
                         const DenseTensor& tensor,
                         const int64_t row_num = 3,
@@ -58,18 +58,39 @@ void show_2d_gpu_tensor(const CustomContext& dev_ctx,
   const int64_t cols = cpu_tensor.dims()[1];
 
   printf("\nTensor shape = [%d, %d]\n", rows, cols);
-  const int8_t* cpu_ptr = cpu_tensor.data<int8_t>();
+  const DataType* cpu_ptr = cpu_tensor.data<DataType>();
 
   for (int r = 0; r < row_num; r++) {
     for (int c = 0; c < col_num; c++) {
-      int8_t val = *(cpu_ptr + r * cols + c);
-      printf("%d ", val);
+      DataType val = *(cpu_ptr + r * cols + c);
+      printf("%#x ", val);
     }
     printf("\n");
   }
   printf("\n\n");
 }
 
+template <typename DataType>
+void show_1d_gpu_tensor(const CustomContext& dev_ctx,
+                        const DenseTensor& tensor,
+                        const int64_t num = 3) {
+  phi::CPUPlace cpu_place;
+
+  DenseTensor cpu_tensor;
+  phi::Copy(dev_ctx, tensor, cpu_place, true, &cpu_tensor);
+
+  const int64_t nums = cpu_tensor.numel();
+  printf("\nTensor shape = [%d]\n", nums);
+
+  const DataType* cpu_ptr = cpu_tensor.data<DataType>();
+
+  for (int n = 0; n < num; n++) {
+    DataType val = *(cpu_ptr + n);
+    printf("%#x ", val);
+  }
+  printf("\n\n");
+}
+
 void cpu_2d_tensor_transpose(const DenseTensor& input_data,
                              DenseTensor* transposed_data) {
   const int64_t input_data_rows = input_data.dims()[0];
diff --git a/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu
index 8d72ed2138e..efc18693e21 100644
--- a/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu
+++ b/backends/metax_gpu/kernels/metax_kernel/weight_quantize_kernel_register.cu
@@ -116,7 +116,7 @@ void WeightQuantizeKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<T>(scale);
     weight_quant_gpu(dev_ctx,
                      x.data<T>(),
-                     out->data<int8_t>(),
+                     quanted_x.data<int8_t>(),
                      scale->data<T>(),
                      weight_shape,
                      arch,
@@ -141,7 +141,13 @@ void WeightQuantizeKernel(const Context& dev_ctx,
     //                     arch,
     //                     algo);
 #endif
-    MetaxQuantizedWeightLayoutTrans(dev_ctx, algo, weight_shape, out);
+    quanted_x.Resize({m / 2, n});
+
+    std::vector<int> axis = {1, 0};
+    funcs::Transpose<Context, int8_t, 2> trans;
+    trans(dev_ctx, quanted_x, out, axis);
+
+    out->Resize({n / 2, m});
   } else if (algo == "w4a8") {
     weight_permute_gpu_w4a8(dev_ctx,
                             x.data<T>(),