diff --git a/paddle/fluid/framework/ir/xpu/weight_only_linear_xpu_pass.cc b/paddle/fluid/framework/ir/xpu/weight_only_linear_xpu_pass.cc index 08d975512e2f6d..17772aa110d207 100644 --- a/paddle/fluid/framework/ir/xpu/weight_only_linear_xpu_pass.cc +++ b/paddle/fluid/framework/ir/xpu/weight_only_linear_xpu_pass.cc @@ -222,8 +222,10 @@ void PermuteINT8WeightOnlyPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, common::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); - - ApplyPermuteINT8WeightOnly(graph); + auto version = phi::backends::xpu::get_xpu_version(-1); + if (version == phi::backends::xpu::XPUVersion::XPU2) { + ApplyPermuteINT8WeightOnly(graph); + } } } // namespace ir diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index b2fa8d3c238773..e5680bbd9c8791 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -1690,6 +1690,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BFLOAT16})}, {"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"warpctc", XPUKernelSet({phi::DataType::FLOAT32})}, + {"weight_only_linear", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})}, {"where_index", XPUKernelSet({phi::DataType::INT32, phi::DataType::BOOL, diff --git a/paddle/phi/kernels/fusion/xpu/weight_only_linear_kernel_xpu.cc b/paddle/phi/kernels/fusion/xpu/weight_only_linear_kernel_xpu.cc index ddc396b5ae36e8..98322a9dfa8a83 100644 --- a/paddle/phi/kernels/fusion/xpu/weight_only_linear_kernel_xpu.cc +++ b/paddle/phi/kernels/fusion/xpu/weight_only_linear_kernel_xpu.cc @@ -17,20 +17,20 @@ namespace phi { template -void WeightOnlyLinearKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& weight, - const paddle::optional& bias, - const DenseTensor& weight_scale, - const std::string& weight_dtype, - const int32_t arch, - const int32_t group_size, - DenseTensor* out) { +void 
WeightOnlyLinearXpuKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& weight_scale, + const std::string& weight_dtype, + const int32_t arch, + const int32_t group_size, + DenseTensor* out) { PADDLE_ENFORCE_EQ( weight_dtype, "int8", common::errors::Fatal( - "WeightOnlyLinearKernel xpu just support int8 weight only")); + "WeightOnlyLinearXpuKernel xpu just support int8 weight only")); phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto xpu_ctx = static_cast(&dev_ctx); dev_ctx.template Alloc(out); @@ -134,6 +134,6 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, PD_REGISTER_KERNEL(weight_only_linear_xpu, XPU, ALL_LAYOUT, - phi::WeightOnlyLinearKernel, + phi::WeightOnlyLinearXpuKernel, phi::dtype::float16, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/weight_only_linear_kernel.cc b/paddle/phi/kernels/xpu/weight_only_linear_kernel.cc new file mode 100644 index 00000000000000..55cc5c84327cbc --- /dev/null +++ b/paddle/phi/kernels/xpu/weight_only_linear_kernel.cc @@ -0,0 +1,141 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#ifdef PADDLE_WITH_XPU_XRE5 +#include "xblas/xblas_legacy_api.h" +#endif + +namespace phi { /* WeightOnlyLinearKernel (XPU): calls baidu::xpu::xblas::fc_fusion on x (viewed as m-by-k), weight (n-by-k) and weight_scale, with an optional bias, and writes an m-by-n result into out. Only weight_dtype int8 reaches fc_fusion; int4 and every other value end in PD_THROW. The arch and group_size parameters are not read by this body. NOTE(review): the xblas header above is included only under PADDLE_WITH_XPU_XRE5, while the body uses baidu::xpu::xblas names unconditionally - confirm the build without XRE5. */ +template +void WeightOnlyLinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& weight_scale, + const std::string& weight_dtype, + const int32_t arch, + const int32_t group_size, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + int64_t n = weight.dims()[0]; + int64_t k = weight.dims()[1]; + int64_t m = x.numel() / k; /* n and k are the two weight dims; m is the element count of x divided by k, so every leading dim of x is folded into m */ + if (weight_dtype == "int4") { + n = n * 2; /* NOTE(review): presumably an int4 weight stores two values per element along dim 0 - confirm */ + } + out->Resize({static_cast(m), static_cast(n)}); + dev_ctx.template Alloc(out); /* out is resized to 2-D {m, n} before it is allocated */ + + DenseTensor bias_fp32; /* temporary that receives a converted copy of the bias when the bias dtype is FLOAT16; n elements are cast into it below. NOTE(review): a bias of any other dtype is handed to the epilogue without conversion - confirm this is valid for a BFLOAT16 bias. */ + if (bias.is_initialized() && bias.get().dtype() == phi::DataType::FLOAT16) { + bias_fp32.Resize(bias.get().dims()); + dev_ctx.template Alloc(&bias_fp32); + int r = baidu::xpu::api::cast( + dev_ctx.x_context(), + reinterpret_cast( + bias.get().data()), + bias_fp32.data(), + n); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); + } + auto input_x = reinterpret_cast(x.data()); + auto input_y = reinterpret_cast(out->data()); + + baidu::xpu::xblas::FcFusionTensor tensor_x{ + input_x, nullptr, m, k, k, false}; + baidu::xpu::xblas::FcFusionTensor tensor_y_const{ + input_y, nullptr, m, n, n, false}; + baidu::xpu::xblas::FcFusionTensor tensor_y{ + input_y, nullptr, m, n, n, false}; /* tensor_y_const and tensor_y both wrap the buffer of out */ + baidu::xpu::xblas::FcFusionEpilogue epilogue{ + api::Activation_t::LINEAR, + bias.is_initialized() ? (bias.get().dtype() == phi::DataType::FLOAT16 + ? bias_fp32.data() + : bias.get().data()) + : nullptr, + nullptr, + weight_scale.dims().size() != 0 ? 
weight_scale.data() : nullptr, + 0, + 1, + nullptr}; /* epilogue: LINEAR activation, the bias pointer selected above (nullptr when no bias is given), and the weight_scale data unless weight_scale has zero dims */ + + if (weight_dtype == "int8") { + // using TGEMM=int8_wo_t; + using TGEMM = float; + baidu::xpu::xblas::FcFusionDesc desc{1.0f, 0.0f}; + baidu::xpu::xblas::FcFusionTensor tensor_w{ + reinterpret_cast(weight.data()), + nullptr, + n, + k, + k, + true}; /* weight is wrapped as n-by-k; NOTE(review): the trailing true presumably marks it as transposed - confirm against the xblas API */ + int r1 = baidu::xpu::xblas::fc_fusion(dev_ctx.x_context(), + tensor_x, + tensor_w, + tensor_y_const, + tensor_y, + desc, + epilogue); + PD_CHECK(r1 == 0, "xblas::fc_fusion failed"); + } else if (weight_dtype == "int4") { + // baidu::xpu::xblas::FcFusionDesc + // desc{1.0f, 0.0f}; + // baidu::xpu::xblas::FcFusionTensor tensor_w{ + // reinterpret_cast(weight.data()), + // nullptr, + // n, + // k, + // k, + // true}; + // int r1 = baidu::xpu::xblas::fc_fusion(dev_ctx.x_context(), + // tensor_x, + // tensor_w, + // tensor_y_const, + // tensor_y, + // desc, + // epilogue); + // PD_CHECK(r1 == 0, "xblas::fc_fusion failed"); + PD_THROW("unsupported weight_dtype=int4"); /* int4 is rejected here; the fc_fusion call for it is commented out above */ + } else { + PD_THROW("unsupported weight_dtype: ", weight_dtype); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(weight_only_linear, + XPU, + ALL_LAYOUT, + phi::WeightOnlyLinearKernel, + phi::dtype::float16, + phi::dtype::bfloat16) {} /* registers the kernel under the name weight_only_linear for XPU, ALL_LAYOUT, with float16 and bfloat16 */